[Misc][LoRA] Abstract PunicaWrapper (vllm-project#10955)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored and weilong.yu committed Dec 13, 2024
1 parent 0ad90dd commit a60906b
Showing 9 changed files with 1,058 additions and 749 deletions.
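The heart of the change, before the per-file hunks: call sites no longer instantiate `PunicaWrapper` directly; they ask a `get_punica_wrapper` factory for a platform-specific implementation of a common `PunicaWrapperBase`. The sketch below is illustrative only — the class and function names follow the diff, but the bodies and the dispatch logic are assumptions, not the actual vLLM code:

```python
# Illustrative sketch of the abstraction, not the code from this commit.
from abc import ABC
from typing import Union

import torch


class PunicaWrapperBase(ABC):
    """Assumed common base: shared construction parameters and interface."""

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str]):
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_batches = max_batches
        self.device = torch.device(device)


class PunicaWrapperGPU(PunicaWrapperBase):
    """Stand-in for the CUDA/ROCm backend that launches the Triton kernels."""


def get_punica_wrapper(max_num_batched_tokens: int, max_batches: int,
                       device: Union[torch.device, str]) -> PunicaWrapperBase:
    # Dispatch on the active platform; only the GPU branch is sketched here.
    if torch.cuda.is_available():
        return PunicaWrapperGPU(max_num_batched_tokens, max_batches, device)
    raise NotImplementedError("no Punica backend available for this platform")
```

With that shape, the LoRA layers, the model manager, and the tests depend only on the factory and the base type, which is what the hunks below switch them to.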
49 changes: 33 additions & 16 deletions tests/lora/test_layers.py
@@ -28,7 +28,7 @@
# yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -48,11 +48,12 @@
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
# TODO: Modify this based on platform
DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We will launch different triton kernels between the prefill and decode
#For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]

@@ -192,9 +193,18 @@ def create_random_inputs(
return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
if current_platform.is_cuda_alike():
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

return type(punica_wrapper) is PunicaWrapperGPU
else:
return False


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:

torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
@@ -296,7 +307,7 @@ def create_random_embedding_layer():
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
@@ -432,7 +444,7 @@ def create_random_embedding_layer():

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
@@ -563,15 +576,16 @@ def _pretest():

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -675,15 +689,16 @@ def create_random_linear_replicated_layer():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -797,15 +812,16 @@ def create_random_linear_parallel_layer():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
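One possible resolution of the `# TODO: Modify this based on platform` note above — not part of this commit, purely a hedged sketch — would be to derive the test device list from whatever accelerator is actually present:

```python
# Hypothetical sketch for the TODO above; not part of this commit.
from typing import List

import torch


def available_test_devices() -> List[str]:
    if torch.cuda.is_available():
        # Cap at two GPUs, mirroring the existing list comprehension.
        return [f"cuda:{i}" for i in range(min(2, torch.cuda.device_count()))]
    return ["cpu"]


DEVICES = available_test_devices()
```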
7 changes: 3 additions & 4 deletions vllm/lora/layers.py
@@ -17,7 +17,6 @@
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import PunicaWrapper
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
@@ -33,7 +32,7 @@
VocabParallelEmbedding)

if TYPE_CHECKING:
pass
from vllm.lora.punica_wrapper import PunicaWrapperBase


def _get_lora_device(base_layer: nn.Module) -> torch.device:
@@ -115,9 +114,9 @@ def set_lora(

def set_mapping(
self,
punica_wrapper: PunicaWrapper,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapper = punica_wrapper
self.punica_wrapper: PunicaWrapperBase = punica_wrapper

@classmethod
def can_replace_layer(
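The layers.py hunks above drop the runtime import of the concrete wrapper and keep only a `TYPE_CHECKING` import for the annotation. A minimal sketch of that pattern — the class here is a placeholder, not the real LoRA layer:

```python
# Minimal sketch of the deferred-typing pattern used above; placeholder class.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; nothing from vllm is imported at runtime.
    from vllm.lora.punica_wrapper import PunicaWrapperBase


class SomeLoRALayer:
    def set_mapping(self, punica_wrapper) -> None:
        # Any backend returned by get_punica_wrapper() can be injected here;
        # the annotation keeps type checkers happy without a runtime
        # dependency on a specific implementation.
        self.punica_wrapper: "PunicaWrapperBase" = punica_wrapper
```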
8 changes: 4 additions & 4 deletions vllm/lora/models.py
@@ -21,7 +21,7 @@
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
@@ -331,9 +331,9 @@ def __init__(
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
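Continuing the illustrative sketch from the top of the diff (same assumed `PunicaWrapperBase` / `PunicaWrapperGPU` stand-ins), the payoff of routing the `__init__` above through the factory is that a hypothetical new backend only needs a subclass and a dispatch branch, while call sites such as `self.punica_wrapper = get_punica_wrapper(...)` stay unchanged:

```python
# Hypothetical extension, not in this commit; reuses the earlier stand-ins.
class PunicaWrapperCPU(PunicaWrapperBase):
    """Imaginary CPU backend exposing the same interface."""


def get_punica_wrapper(max_num_batched_tokens, max_batches, device):
    # ROCm devices also report "cuda" in torch, so this covers CUDA-alike GPUs.
    if torch.device(device).type == "cuda":
        return PunicaWrapperGPU(max_num_batched_tokens, max_batches, device)
    return PunicaWrapperCPU(max_num_batched_tokens, max_batches, device)
```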