[Misc][LoRA] Replace hardcoded cuda device with configurable argument (vllm-project#10223)

Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored and weilong.yu committed Dec 13, 2024
1 parent e39d0e8 commit 7a2db87
Showing 6 changed files with 174 additions and 80 deletions.
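In short, the change threads the target device through the LoRA test helpers instead of relying on an implicit, hardcoded "cuda" placement. A minimal, self-contained sketch of the pattern follows; the make_inputs helper name is illustrative only, not part of the diff:

import torch


def make_inputs(num_inputs: int,
                size: tuple,
                device: torch.device = "cuda") -> list:
    # Before this change, the test helpers created tensors with an implicit
    # "cuda" placement; passing `device` explicitly keeps a multi-GPU test
    # run on the intended card (e.g. "cuda:1" when two GPUs are available).
    return [
        torch.rand(size, dtype=torch.float16, device=device)
        for _ in range(num_inputs)
    ]


# Pin the inputs to a specific GPU; fall back to cuda:0 on single-GPU hosts.
target = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
xs = make_inputs(3, (1, 4096), device=target)
assert all(x.device == torch.device(target) for x in xs)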
56 changes: 36 additions & 20 deletions tests/lora/test_layers.py
@@ -51,6 +51,7 @@
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We launch different triton kernels for the prefill and decode stages, so
# both need to be verified: prefill stage (True) and decode stage (False).
STAGES = [True, False]
@@ -120,11 +121,12 @@ def populate_loras(
subloras: List[LoRALayerWeights] = []
sublora_len = layer_weights.shape[0] // repeats
for i in range(repeats):
sublora = DummyLoRAManager().init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora = DummyLoRAManager(
layer_weights.device).init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.optimize()
@@ -152,6 +154,7 @@ def create_random_inputs(
input_size: Tuple[int, ...],
input_range: Tuple[float, float],
input_type: torch.dtype = torch.int,
device: torch.device = "cuda"
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
"""Creates random inputs.
@@ -173,10 +176,14 @@
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
torch.randint(low=int(low), high=int(high), size=input_size))
torch.randint(low=int(low),
high=int(high),
size=input_size,
device=device))
else:
inputs.append(
torch.rand(size=input_size, dtype=input_type) * high + low)
torch.rand(size=input_size, dtype=input_type, device=device) *
high + low)

lora_id = random.choice(active_lora_ids)
index_mapping += [lora_id] * input_size[0]
@@ -191,6 +198,10 @@
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# For multi-GPU testing of Triton kernels, we must explicitly set the CUDA
# device; see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
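A short aside on the hunk above: a hedged sketch of the device setup these tests now perform at the start of each case (pin_test_device is an illustrative name, not part of the diff). Setting only torch's default device is not enough for Triton on multi-GPU hosts, hence the explicit torch.cuda.set_device call:

import torch


def pin_test_device(device: str) -> None:
    # Per the comment in the diff above, Triton needs the current CUDA device
    # set explicitly on multi-GPU hosts (triton-lang/triton#2925).
    torch.cuda.set_device(device)
    # Plain torch factory calls (torch.zeros, torch.rand, ...) then follow
    # the same placement.
    torch.set_default_device(device)


pin_test_device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")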
@@ -225,7 +236,7 @@ def create_random_embedding_layer():
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -263,7 +274,7 @@ def create_random_embedding_layer():
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -291,6 +302,7 @@ def create_random_embedding_layer():
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size, stage) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -345,7 +357,7 @@ def create_random_embedding_layer():
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -400,7 +412,7 @@ def create_random_embedding_layer():
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)
device=device)
original_inputs = deepcopy(inputs)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
@@ -426,6 +438,7 @@ def create_random_embedding_layer():
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -471,7 +484,7 @@ def _pretest():
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -520,7 +533,7 @@ def _pretest():
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -554,6 +567,7 @@ def _pretest():
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@@ -592,7 +606,7 @@ def create_random_linear_replicated_layer():
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -631,7 +645,7 @@ def create_random_linear_replicated_layer():
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -658,6 +672,7 @@ def create_random_linear_replicated_layer():
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@@ -706,7 +721,7 @@ def create_random_linear_parallel_layer():
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -745,7 +760,7 @@ def create_random_linear_parallel_layer():
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -772,6 +787,7 @@ def create_random_linear_parallel_layer():
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
@@ -842,7 +858,7 @@ class FakeConfig:
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -883,7 +899,7 @@ class FakeConfig:
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float16,
)
device=device)
lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
@@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16,
)
device=device)

lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors),
