Commit ccdc490

[Core] Change LoRA embedding sharding to support loading methods (vllm-project#5038)
1 parent a31cab7 · commit ccdc490

11 files changed (+661 -129 lines)
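
Judging from the tests added below (an inference from this diff, not a statement of the vLLM source), each tensor-parallel shard of VocabParallelEmbedding stores its rows in the order [original-vocab slice | org padding | added-vocab slice | added padding], and get_sharded_to_full_mapping() returns the indices that reassemble the per-rank shards into the original [org vocab, added vocab, padding] order that checkpoint-loading code expects. A minimal sketch of that per-rank layout, with ShardIndices as a hypothetical stand-in for the shard_indices attribute the new tests read:

from dataclasses import dataclass

@dataclass
class ShardIndices:
    # Hypothetical stand-in for vocab_embedding.shard_indices used in the tests below.
    org_vocab_start_index: int
    org_vocab_end_index: int
    num_org_elements: int
    num_org_elements_padded: int
    added_vocab_start_index: int
    added_vocab_end_index: int
    num_added_elements: int
    num_added_elements_padded: int

def rank_token_ids(shard: ShardIndices) -> list:
    """Token ids held by one rank, with -1 marking padding rows."""
    ids = list(range(shard.org_vocab_start_index, shard.org_vocab_end_index))
    ids += [-1] * (shard.num_org_elements_padded - shard.num_org_elements)
    ids += list(range(shard.added_vocab_start_index, shard.added_vocab_end_index))
    ids += [-1] * (shard.num_added_elements_padded - shard.num_added_elements)
    return ids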

.buildkite/test-pipeline.yaml (+2 -8)

@@ -46,6 +46,7 @@ steps:
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s spec_decode/e2e/test_integration_dist.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py

 - label: Distributed Tests (Multiple Groups)
   #mirror_hardwares: [amd]
@@ -138,14 +139,7 @@ steps:
   num_gpus: 4
   # This test runs llama 13B, so it is required to run on 4 GPUs.
   commands:
-  # Temporarily run this way because we cannot clean up GPU mem usage
-  # for multi GPU tests.
-  # TODO(sang): Fix it.
-  - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
-  - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
-  - pytest -v -s lora/test_long_context.py::test_self_consistency
-  - pytest -v -s lora/test_long_context.py::test_quality
-  - pytest -v -s lora/test_long_context.py::test_max_len
+  - pytest -v -s -x lora/test_long_context.py

 - label: Tensorizer Test
   #mirror_hardwares: [amd]

tests/conftest.py (+21)

@@ -1,6 +1,8 @@
 import contextlib
 import gc
 import os
+import subprocess
+import sys
 from typing import Any, Dict, List, Optional, Tuple, TypeVar

 import pytest
@@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
     # To capture vllm log, we should enable propagate=True temporarily
     # because caplog depends on logs propagated to the root logger.
     yield caplog
+
+
+@pytest.fixture(scope="session")
+def num_gpus_available():
+    """Get number of GPUs without initializing the CUDA context
+    in current process."""
+
+    try:
+        out = subprocess.run([
+            sys.executable, "-c",
+            "import torch; print(torch.cuda.device_count())"
+        ],
+                             capture_output=True,
+                             check=True,
+                             text=True)
+    except subprocess.CalledProcessError as e:
+        logger.warning("Failed to get number of GPUs.", exc_info=e)
+        return 0
+    return int(out.stdout.strip())
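
The fixture above runs torch.cuda.device_count() in a subprocess so the pytest process itself never initializes the CUDA context. A hedged usage sketch follows (the test name and body are hypothetical; the pattern mirrors the tests/lora/test_llama.py change further down in this commit):

import pytest

# Hypothetical consumer of the session-scoped num_gpus_available fixture above.
@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_needs_multiple_gpus(tp_size, num_gpus_available):
    if num_gpus_available < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
    ...  # run the tp_size-GPU test body here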

tests/lora/conftest.py (+16 -2)

@@ -42,10 +42,24 @@ def cleanup():
     ray.shutdown()


+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    """Allow subdirectories to skip global cleanup by overriding this fixture.
+    This can provide a ~10x speedup for non-GPU unit tests since they don't need
+    to initialize torch.
+    """
+
+    if request.node.get_closest_marker("skip_global_cleanup"):
+        return False
+
+    return True
+
+
 @pytest.fixture(autouse=True)
-def cleanup_fixture():
+def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
-    cleanup()
+    if should_do_global_cleanup_after_test:
+        cleanup()


 @pytest.fixture
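
The new fixture lets a test opt out of the heavyweight global cleanup by carrying a marker. A hedged usage sketch (the test is hypothetical; the marker name skip_global_cleanup comes from the get_closest_marker call in the diff above):

import pytest

# Hypothetical non-GPU unit test that opts out of the autouse cleanup_fixture's
# cleanup() call via the skip_global_cleanup marker checked above.
@pytest.mark.skip_global_cleanup
def test_pure_python_helper():
    assert 1 + 1 == 2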

tests/lora/test_layers.py (+217 -2)

@@ -2,6 +2,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch

 import pytest
 import torch
@@ -32,7 +33,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed

 from .utils import DummyLoRAManager
@@ -427,7 +428,8 @@ def _pretest():
         logits_processor = LogitsProcessor(
             vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
-            logits_processor, 1024, linear.weight.dtype, linear.weight.device)
+            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
+            None)
         lora_logits_processor.create_lora_weights(max_loras, lora_config)

         return linear, logits_processor, lora_logits_processor
@@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,

     torch.allclose(ref_q, actual_q)
     torch.allclose(ref_k, actual_k)
+
+
+@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
+@pytest.mark.parametrize("seed", list(range(256)))
+def test_vocab_parallel_embedding_indices(tp_size, seed):
+    random.seed(seed)
+    vocab_size = random.randint(4000, 64000)
+    added_vocab_size = random.randint(0, 1024)
+    org_vocab_size = vocab_size - added_vocab_size
+    last_org_vocab_end_index = 0
+    last_added_vocab_end_index = org_vocab_size
+    computed_vocab_size = 0
+    computed_org_vocab_size = 0
+    computed_added_vocab_size = 0
+    vocab_size_padded = -1
+
+    all_org_tokens = []
+    all_added_tokens = []
+    token_ids = []
+
+    for tp_rank in range(tp_size):
+        with patch(
+                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
+                return_value=tp_rank
+        ), patch(
+                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
+                return_value=tp_size):
+            vocab_embedding = VocabParallelEmbedding(
+                vocab_size, 1, org_num_embeddings=org_vocab_size)
+        vocab_size_padded = vocab_embedding.num_embeddings_padded
+        shard_indices = vocab_embedding.shard_indices
+        # Assert that the ranges are contiguous
+        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
+        assert (shard_indices.added_vocab_start_index ==
+                last_added_vocab_end_index)
+
+        # Ensure that we are not exceeding the vocab size
+        computed_vocab_size += shard_indices.num_elements_padded
+        computed_org_vocab_size += shard_indices.num_org_elements
+        computed_added_vocab_size += shard_indices.num_added_elements
+
+        # Ensure that the ranges are not overlapping
+        all_org_tokens.extend(
+            range(shard_indices.org_vocab_start_index,
+                  shard_indices.org_vocab_end_index))
+        all_added_tokens.extend(
+            range(shard_indices.added_vocab_start_index,
+                  shard_indices.added_vocab_end_index))
+
+        token_ids.extend(
+            range(shard_indices.org_vocab_start_index,
+                  shard_indices.org_vocab_end_index))
+        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
+                                 shard_indices.num_org_elements))
+        token_ids.extend(
+            range(shard_indices.added_vocab_start_index,
+                  shard_indices.added_vocab_end_index))
+        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
+                                 shard_indices.num_added_elements))
+
+        last_org_vocab_end_index = shard_indices.org_vocab_end_index
+        last_added_vocab_end_index = shard_indices.added_vocab_end_index
+
+    assert computed_vocab_size == vocab_size_padded
+    assert computed_org_vocab_size == org_vocab_size
+    assert computed_added_vocab_size == added_vocab_size
+
+    # Ensure that the ranges are not overlapping
+    assert len(all_org_tokens) == len(set(all_org_tokens))
+    assert len(all_added_tokens) == len(set(all_added_tokens))
+    assert not set(all_org_tokens).intersection(set(all_added_tokens))
+
+    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
+    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
+    assert reindex_mapping is not None or tp_size == 1
+    if reindex_mapping is not None:
+        reindexed_token_ids = token_ids_tensor[reindex_mapping]
+        expected = torch.tensor(list(range(0, vocab_size)))
+        assert reindexed_token_ids[:vocab_size].equal(expected)
+        assert torch.all(reindexed_token_ids[vocab_size:] == -1)
+
+
+def test_get_masked_input_and_mask():
+    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+
+    # base tp 1 case, no padding
+    modified_x, _ = get_masked_input_and_mask(x,
+                                              org_vocab_start_index=0,
+                                              org_vocab_end_index=8,
+                                              added_vocab_start_index=8,
+                                              added_vocab_end_index=12,
+                                              num_org_vocab_padding=0)
+    assert torch.equal(x, modified_x)
+
+    # tp 2 case, no padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_1, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=8,
+        added_vocab_start_index=10,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=0)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
+
+    # tp 4 case, no padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=2,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=9,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_1, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=2,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=9,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=0)
+    modified_x_rank_2, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=6,
+        added_vocab_start_index=10,
+        added_vocab_end_index=11,
+        num_org_vocab_padding=0)
+    modified_x_rank_3, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=6,
+        org_vocab_end_index=8,
+        added_vocab_start_index=11,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=0)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
+    assert torch.equal(modified_x_rank_2,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
+    assert torch.equal(modified_x_rank_3,
+                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
+
+    # base tp 1 case, with padding
+    modified_x, _ = get_masked_input_and_mask(x,
+                                              org_vocab_start_index=0,
+                                              org_vocab_end_index=8,
+                                              added_vocab_start_index=8,
+                                              added_vocab_end_index=12,
+                                              num_org_vocab_padding=2)
+    assert torch.equal(modified_x,
+                       torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
+
+    # tp 2 case, with padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_1, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=8,
+        added_vocab_start_index=10,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=2)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
+
+    # tp 4 case, with padding
+    modified_x_rank_0, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=0,
+                                                     org_vocab_end_index=2,
+                                                     added_vocab_start_index=8,
+                                                     added_vocab_end_index=9,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_1, _ = get_masked_input_and_mask(x,
+                                                     org_vocab_start_index=2,
+                                                     org_vocab_end_index=4,
+                                                     added_vocab_start_index=9,
+                                                     added_vocab_end_index=10,
+                                                     num_org_vocab_padding=2)
+    modified_x_rank_2, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=4,
+        org_vocab_end_index=6,
+        added_vocab_start_index=10,
+        added_vocab_end_index=11,
+        num_org_vocab_padding=2)
+    modified_x_rank_3, _ = get_masked_input_and_mask(
+        x,
+        org_vocab_start_index=6,
+        org_vocab_end_index=8,
+        added_vocab_start_index=11,
+        added_vocab_end_index=12,
+        num_org_vocab_padding=2)
+    assert torch.equal(modified_x_rank_0,
+                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
+    assert torch.equal(modified_x_rank_1,
+                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
+    assert torch.equal(modified_x_rank_2,
+                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
+    assert torch.equal(modified_x_rank_3,
+                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
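
For readers of this diff, here is a minimal sketch of what get_masked_input_and_mask has to do to satisfy test_get_masked_input_and_mask above. It is derived from the test expectations only, not copied from vllm/model_executor/layers/vocab_parallel_embedding.py, so treat the body as an assumption about the real implementation:

import torch

def masked_input_and_mask_sketch(x: torch.Tensor, org_vocab_start_index: int,
                                 org_vocab_end_index: int,
                                 added_vocab_start_index: int,
                                 added_vocab_end_index: int,
                                 num_org_vocab_padding: int):
    """Map global token ids to this shard's local row indices; ids owned by
    other shards are zeroed out and reported in the returned mask."""
    org_mask = (x >= org_vocab_start_index) & (x < org_vocab_end_index)
    added_mask = (x >= added_vocab_start_index) & (x < added_vocab_end_index)
    # Added-vocab rows sit right after this shard's (padded) org-vocab rows.
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid = org_mask | added_mask
    local = valid * (x - org_mask * org_vocab_start_index -
                     added_mask * added_offset)
    return local, ~valid

Feeding it the tp 2, padding 2 inputs from the test reproduces the expected rank-0 output [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0].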

tests/lora/test_llama.py (+7 -10)

@@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
     return generated_texts


-@pytest.mark.parametrize("tp_size", [1])
-def test_llama_lora(sql_lora_files, tp_size):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < tp_size:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+@pytest.mark.parametrize("tp_size", [1, 2, 4])
+def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

     llm = vllm.LLM(MODEL_PATH,
                    enable_lora=True,
@@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
     print("removing lora")


-@pytest.mark.skip("Requires multiple GPUs")
-def test_llama_tensor_parallel_equality(sql_lora_files):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 4:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
+    if num_gpus_available < 4:
+        pytest.skip("Not enough GPUs for tensor parallelism 4")

     llm_tp1 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
