From 16144cd2ab06a7440358c0c298e86066bdde2772 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 28 Nov 2024 09:00:54 +0000 Subject: [PATCH 1/8] Init Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 16 +- vllm/lora/layers.py | 15 +- vllm/lora/models.py | 8 +- vllm/lora/punica.py | 611 ----------------------------------- vllm/lora/punica_base.py | 287 ++++++++++++++++ vllm/lora/punica_gpu.py | 311 ++++++++++++++++++ vllm/lora/punica_selector.py | 11 + vllm/lora/utils.py | 159 ++++++++- 8 files changed, 786 insertions(+), 632 deletions(-) delete mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/punica_base.py create mode 100644 vllm/lora/punica_gpu.py create mode 100644 vllm/lora/punica_selector.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 15e576cb065c7..5dda1870e67d1 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, PackedLoRALayerWeights) -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_gpu import PunicaWrapperGPU from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -205,7 +205,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -305,7 +305,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -441,7 +441,7 @@ def 
test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -569,7 +569,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None: torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -674,7 +674,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -789,7 +789,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -941,7 +941,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seed = 0 current_platform.seed_everything(seed) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = PunicaWrapperGPU(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3701988ff692f..4391fd335ecf0 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -17,7 +17,6 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide 
-from vllm.lora.punica import PunicaWrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -30,7 +29,7 @@ VocabParallelEmbedding) if TYPE_CHECKING: - pass + from vllm.lora.punica_base import PunicaWrapperBase def _get_lora_device(base_layer: nn.Module) -> torch.device: @@ -169,9 +168,9 @@ def set_lora( def set_mapping( self, - punica_wrapper: PunicaWrapper, + punica_wrapper, ): - self.punica_wrapper: PunicaWrapper = punica_wrapper + self.punica_wrapper: PunicaWrapperBase = punica_wrapper @classmethod def can_replace_layer( @@ -308,10 +307,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # Embedding layer only need expand op - self.punica_wrapper.add_expand(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) return full_output.view_as(full_output_org) @classmethod diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2ffefe61427e3..18ef91f55c10e 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -21,7 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_selector import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) @@ -331,9 +331,9 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device) + self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device) # Scaling factor -> offset 
to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py deleted file mode 100644 index 082041f390750..0000000000000 --- a/vllm/lora/punica.py +++ /dev/null @@ -1,611 +0,0 @@ -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union - -import torch - -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.sgmv_shrink import sgmv_shrink - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - - -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: - """ - Get the information required for the sgmv kernel. With the features: - 1. If consecutive requests in the batch use the same LoRA, this function - will combine them into a single request, improving sgmv kernel inference - performance. - 2. At the beginning of each prefill stage inference, recalculations are - needed based on the input, but only once. 
- """ - - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) - cum_result = torch.cumsum(seq_length_tensor, dim=0) - b_seq_start_tensor = torch.zeros_like(seq_length_tensor) - b_seq_start_tensor[1:].copy_(cum_result[:-1]) - max_length = seq_length_tensor.max().item() - token_nums = seq_length_tensor.sum().item() - batch_size = lora_indices_tensor.size(0) - no_lora = False - # -1 means no lora should be applied. Use `no_lora` to determine whether - # the current step requires LoRA. If LoRA is not needed, the prefill stage - # does not need to launch the triton kernel, which can improve performance - if batch_size == 1 and lora_indices_tensor == -1: - no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) - - -# TODO see if this can be vectorized -def convert_mapping( - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. 
- sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. It contains - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. 
optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, - lora_indices, - embedding_indices, - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. 
- indices_len = [ - base_indices.shape[-1], - sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) - - return ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_indices, - indices_len, - ) - - -class PunicaWrapper: - """ - PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the punica kernel. - """ - - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str]): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) - self._long_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - - # 5 is the number of indicies tensors. 
- # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices,long_lora_indices - self.indices_len: List[Optional[int]] = [None] * 5 - # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) - self.device: torch.device = device - self.max_length: int = 0 - self.token_nums: int = 0 - self.batch_size: int = -1 - self.is_prefill = False - self.no_lora = False - - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. 
- self._update_prefill_metada(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False - - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_offsets_tensor, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - extra_vocab_size, - self.device, - long_lora_context, - ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() - self.indices_len[:] = indices_len - - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: - - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) - self.batch_size = batch_size - self.max_length = max_length - self.token_nums = token_nums - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for 
prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions. - 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - 4. batch_size: Batch size after clustering identical lora indices. - 5. max_length: The maximum sequence length in the batch. - 6. token_nums: The token numbers in the batch. - """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) - - @property - def token_lora_indices(self) -> torch.Tensor: - """ - This property provides the lora indices corresponding to each token - in the batch. An index of -1 means no lora should be applied. - """ - token_lora_len = self.indices_len[0] - return self._token_lora_indices[:token_lora_len] - - @property - def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA. - """ - sampler_indices_len = self.indices_len[1] - return self._sampler_indices[:sampler_indices_len] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - indices_padded_len = self.indices_len[2] - return self._sampler_indices_padded[:indices_padded_len] - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - embeddings_indices_len = self.indices_len[3] - return self._embeddings_indices[:, :embeddings_indices_len] - - @property - def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora. 
- """ - long_lora_len = self.indices_len[4] - return self._long_lora_indices[:long_lora_len] - - def shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) - - def shrink_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - - def expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_input, - ) - - def expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - - def expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_input, - ) - - def expand_slice_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) - - def add_shrink( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `shrink_prefill` function should be called. 
- Otherwise, it is the decode stage, and the shrink_decode function - should be called. - """ - shrink_fun: Callable = (self.shrink_prefill - if self.is_prefill else self.shrink_decode) - shrink_fun(y, x, w_t_all, scale) - - def add_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool = True, - ): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'b. - When `is_prefill` is true, it indicates that it is currently the - prefill stage, and the `expand_prefill` function should be called. - Otherwise, it is the decode stage, and the expand_decode function - should be called. - """ - - expand_fun: Callable = (self.expand_prefill - if self.is_prefill else self.expand_decode) - expand_fun(y, x, w_t_all, add_input) - - def add_expand_slice(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): - """ - Similar to `add_expand` - """ - - expand_slice_fun: Callable = (self.expand_slice_prefill - if self.is_prefill else - self.expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - - def add_lora(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale: float, - y_offset: Optional[int] = None, - y_slice_size: Optional[int] = None, - *, - buffer: Optional[torch.Tensor] = None) -> None: - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - Args: - y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor - wa_t_all (torch.Tensor): lora_a's weight - wb_t_all (torch.Tensor): lora_b's weight - scale (float): Scaling factor. - y_offset (Optional[int], optional): Offset to apply to the starting - column of y. 
- y_slice_size (Optional[int], optional): Size of the y column slice. - buffer (Optional[torch.Tensor], optional): Defaults to None. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - - self.add_shrink(buffer, x, wa_t_all, scale) - if y_offset is None and y_slice_size is None: - self.add_expand(y, buffer, wb_t_all, add_input=True) - else: - self.add_expand_slice(y, - buffer, - wb_t_all, - y_offset, - y_slice_size, - add_input=True) - y = y.view_as(y_org) - - def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - scale: float, - output_slices: Tuple[int, ...]) -> None: - """ - Applies lora to each input. Similar to add_lora, This method is - used for layers that are composed of multiple sublayers - (slices) packed together. 
- """ - y_org = y - x = x.view(-1, x.shape[-1]) - y = y.view(-1, y.shape[-1]) - offset_left = 0 - # TODO fuse these kernels - for slice_idx in range(len(output_slices)): - self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], scale, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] - - y = y.view_as(y_org) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None) -> None: - """ - LogitsProcessorWithLoRA always using bgmv - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - - bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) - y = y.view_as(y_org) diff --git a/vllm/lora/punica_base.py b/vllm/lora/punica_base.py new file mode 100644 index 0000000000000..a0ab76061b0d5 --- /dev/null +++ b/vllm/lora/punica_base.py @@ -0,0 +1,287 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm.lora.utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circular import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperBase(ABC): + """ + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel.
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str]): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indices tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices, long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators.
+ self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for 
prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. 
+ """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + # TODO: we also need to consider lora bias + # @abstractmethod + # def add_bias(self): + # raise NotImplementedError + + # @abstractmethod + # def add_bias_slice(self): + # raise NotImplementedError + + @abstractmethod + def add_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float, **kwarg): + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + **kwarg): + raise NotImplementedError + + @abstractmethod + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + **kwarg): + raise NotImplementedError + + @abstractmethod + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + raise NotImplementedError + + @abstractmethod + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, output_slices: Tuple[int, ...], + **kwarg) -> None: + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + **kwarg): + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwarg) -> None: + + raise NotImplementedError diff --git a/vllm/lora/punica_gpu.py b/vllm/lora/punica_gpu.py new file mode 
100644 index 0000000000000..bf9ee6860c465 --- /dev/null +++ b/vllm/lora/punica_gpu.py @@ -0,0 +1,311 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Callable, Optional, Tuple, Union, final + +import torch + +from vllm.lora.punica_base import PunicaWrapperBase +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink + + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. 
+ """ + + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + # No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + # No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice( + x, + w_t_all, + y, + self.token_lora_indices, + y_offset, + y_slice_size, + add_input, + ) + + def add_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float, **kwarg): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. 
+ When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the shrink_decode function + should be called. + """ + shrink_fun: Callable = (self.shrink_prefill + if self.is_prefill else self.shrink_decode) + shrink_fun(y, x, w_t_all, scale) + + def add_expand(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + **kwarg): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'b. + When `is_prefill` is true, it indicates that it is currently the + prefill stage, and the `expand_prefill` function should be called. + Otherwise, it is the decode stage, and the expand_decode function + should be called. + """ + + expand_fun: Callable = (self.expand_prefill + if self.is_prefill else self.expand_decode) + expand_fun(y, x, w_t_all, add_input) + + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + **kwarg): + """ + Similar to `add_expand` + """ + + expand_slice_fun: Callable = (self.expand_slice_prefill + if self.is_prefill else + self.expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None, + **kwarg) -> None: + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. 
+ x (torch.Tensor): Input tensor + wa_t_all (torch.Tensor): lora_a's weight + wb_t_all (torch.Tensor): lora_b's weight + scale (float): Scaling factor. + y_offset (Optional[int], optional): Offset to apply to the starting + column of y. + y_slice_size (Optional[int], optional): Size of the y column slice. + buffer (Optional[torch.Tensor], optional): Defaults to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + self.add_shrink(buffer, x, wa_t_all, scale) + if y_offset is None and y_slice_size is None: + self.add_expand(y, buffer, wb_t_all, add_input=True) + else: + self.add_expand_slice(y, + buffer, + wb_t_all, + y_offset, + y_slice_size, + add_input=True) + y = y.view_as(y_org) + + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, output_slices: Tuple[int, ...], + **kwarg) -> None: + """ + Applies lora to each input. Similar to add_lora, This method is + used for layers that are composed of multiple sublayers + (slices) packed together. 
+ """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora( + y, + x, + lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], + scale, + offset_left, + output_slices[slice_idx], + ) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + **kwarg): + """ + VocabParallelEmbeddingWithLoRA only need expand op + """ + + expand_fun: Callable = (self.expand_prefill + if self.is_prefill else self.expand_decode) + expand_fun(y, x, w_t_all, add_input) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwarg) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_selector.py b/vllm/lora/punica_selector.py new file mode 100644 index 0000000000000..76cbf54b98fff --- /dev/null +++ b/vllm/lora/punica_selector.py @@ -0,0 +1,11 @@ +from vllm.lora.punica_base import PunicaWrapperBase +from vllm.platforms import current_platform + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + if current_platform.is_cuda() or current_platform.is_rocm(): + from vllm.lora.punica_gpu import PunicaWrapperGPU + + return PunicaWrapperGPU(*args, **kwargs) + else: + raise NotImplementedError diff 
--git a/vllm/lora/utils.py b/vllm/lora/utils.py index 5876494ce2824..a9492e5fb2fae 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,8 +1,9 @@ import os import re -from typing import List, Optional, Set, Tuple, Type, Union +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type, Union import huggingface_hub +import torch from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, RepositoryNotFoundError) from torch import nn @@ -31,6 +32,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + logger = init_logger(__name__) _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { @@ -190,3 +196,154 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path return local_snapshot_path + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. 
Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. 
+ indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, 
dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) From eaac5e589f564cf2604b194a576742ed071d537d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 28 Nov 2024 09:25:46 +0000 Subject: [PATCH 2/8] Add kwargs Signed-off-by: Jee Jee Li --- vllm/lora/punica_gpu.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/lora/punica_gpu.py b/vllm/lora/punica_gpu.py index bf9ee6860c465..e5f041fbcb03d 100644 --- a/vllm/lora/punica_gpu.py +++ b/vllm/lora/punica_gpu.py @@ -29,12 +29,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): Multi-LoRA, and to provide the interface for the punica kernel. 
""" - def __init__( - self, - max_num_batched_tokens: int, - max_batches: int, - device: Union[torch.device, str], - ): + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) From 42d99a35a1b45b21068eac6e66b17d0e8ab77bcd Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 28 Nov 2024 10:26:19 +0000 Subject: [PATCH 3/8] Add kwargs Signed-off-by: Jee Jee Li --- vllm/lora/punica_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/punica_base.py b/vllm/lora/punica_base.py index a0ab76061b0d5..775f6dc934ccc 100644 --- a/vllm/lora/punica_base.py +++ b/vllm/lora/punica_base.py @@ -26,7 +26,7 @@ class PunicaWrapperBase(ABC): """ def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str]): + device: Union[torch.device, str], **kwargs): self._token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) From a30a95fd4918911e5629544016beaad91008c3f9 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 6 Dec 2024 05:50:53 +0000 Subject: [PATCH 4/8] Clean up code Signed-off-by: Jee Jee Li --- vllm/lora/punica_base.py | 255 +++++++++++++++++------ vllm/lora/punica_gpu.py | 431 +++++---------------------------------- 2 files changed, 237 insertions(+), 449 deletions(-) diff --git a/vllm/lora/punica_base.py b/vllm/lora/punica_base.py index e9e673ac0033e..0833b6cab8bf6 100644 --- a/vllm/lora/punica_base.py +++ b/vllm/lora/punica_base.py @@ -18,11 +18,108 @@ from vllm.lora.models import LongContextLoRAContext -class PunicaWrapperBase(ABC): +class PunicaWrapperABC(ABC): """ - PunicaWrapper is designed to manage and provide metadata for the punica + PunicaWrapper ABC. 
+ """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + ): + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + ): + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: + """ + Applicable to linear-related lora. 
+ """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the punica kernel. + Multi-LoRA, and to provide the interface for the punica. """ def __init__(self, max_num_batched_tokens: int, max_batches: int, @@ -65,26 +162,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.is_prefill = False self.no_lora = False - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. 
- self._update_prefill_metada(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False - def _update_base_metadata( self, mapping: "LoRAMapping", @@ -140,6 +217,38 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.token_nums = token_nums self.no_lora = no_lora + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + @property def prefill_metadata( self @@ -204,21 +313,33 @@ def long_lora_indices(self) -> torch.Tensor: long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. 
+ self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + @abstractmethod - def add_shrink( - self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, - ): + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale @@ -234,16 +355,15 @@ def add_shrink( raise NotImplementedError @abstractmethod - def add_expand( - self, - y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], - offset_start: int = 0, - add_input=True, - ) -> None: + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -264,19 +384,19 @@ def add_expand( add_input (bool): Defaults to True. 
""" + # TODO: implement it based on torch ops raise NotImplementedError @abstractmethod - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_input: bool = True, - ): + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs): """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. - + and this layer only requires the expand operation. Semantics: y += x @ lora_b_stacked @@ -286,20 +406,21 @@ def add_lora_embedding( lora_b_stacked (torch.Tensor): lora_b's weights. add_input (bool): Default to True. """ + # TODO: implement it based on torch ops raise NotImplementedError @abstractmethod - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: """ Applicable to linear-related lora. @@ -322,7 +443,7 @@ def add_lora_linear( output_slices (Tuple[int, ...]): Every slice's size. buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. """ - + # TODO: implement it based on torch ops raise NotImplementedError @abstractmethod @@ -333,7 +454,8 @@ def add_lora_logits(self, lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None) -> None: + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
@@ -349,4 +471,5 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ + # TODO: implement it based on torch ops raise NotImplementedError diff --git a/vllm/lora/punica_gpu.py b/vllm/lora/punica_gpu.py index fbaf425df54bc..de4368021c82e 100644 --- a/vllm/lora/punica_gpu.py +++ b/vllm/lora/punica_gpu.py @@ -5,12 +5,11 @@ https://arxiv.org/abs/2310.18547 """ -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union,final +from typing import Callable, Optional, Tuple, Union, final import torch from vllm.lora.punica_base import PunicaWrapperBase - from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -21,168 +20,13 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - - -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: - """ - Get the information required for the sgmv kernel. With the features: - 1. If consecutive requests in the batch use the same LoRA, this function - will combine them into a single request, improving sgmv kernel inference - performance. - 2. At the beginning of each prefill stage inference, recalculations are - needed based on the input, but only once. - """ - - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) - cum_result = torch.cumsum(seq_length_tensor, dim=0) - b_seq_start_tensor = torch.zeros_like(seq_length_tensor) - b_seq_start_tensor[1:].copy_(cum_result[:-1]) - max_length = seq_length_tensor.max().item() - token_nums = seq_length_tensor.sum().item() - batch_size = lora_indices_tensor.size(0) - no_lora = False - # -1 means no lora should be applied. Use `no_lora` to determine whether - # the current step requires LoRA. 
If LoRA is not needed, the prefill stage - # does not need to launch the triton kernel, which can improve performance - if batch_size == 1 and lora_indices_tensor == -1: - no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) - - -# TODO see if this can be vectorized -def convert_mapping( - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. 
It contains - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, - lora_indices, - embedding_indices, - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * 
len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. - indices_len = [ - base_indices.shape[-1], - sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) - - return ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_indices, - indices_len, - ) @final class PunicaWrapperGPU(PunicaWrapperBase): """ - PunicaWrapper is designed to manage and provide metadata for the punica + PunicaWrapperGPU is designed to manage and provide metadata for the punica kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the punica kernel. + Multi-LoRA, and to provide the interface for the punica triton kernel. """ def __init__(self, max_num_batched_tokens: int, max_batches: int, @@ -190,146 +34,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. 
- self._update_prefill_metada(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False - - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_offsets_tensor, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - extra_vocab_size, - self.device, - long_lora_context, - ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() - self.indices_len[:] = indices_len - - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: - - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) - self.batch_size = batch_size - self.max_length = max_length - self.token_nums = token_nums - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for 
prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions. - 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - 4. batch_size: Batch size after clustering identical lora indices. - 5. max_length: The maximum sequence length in the batch. - 6. token_nums: The token numbers in the batch. - """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) - - @property - def token_lora_indices(self) -> torch.Tensor: - """ - This property provides the lora indices corresponding to each token - in the batch. An index of -1 means no lora should be applied. - """ - token_lora_len = self.indices_len[0] - return self._token_lora_indices[:token_lora_len] - - @property - def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA. - """ - sampler_indices_len = self.indices_len[1] - return self._sampler_indices[:sampler_indices_len] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - indices_padded_len = self.indices_len[2] - return self._sampler_indices_padded[:indices_padded_len] - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - embeddings_indices_len = self.indices_len[3] - return self._embeddings_indices[:, :embeddings_indices_len] - - @property - def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora. 
- """ - long_lora_len = self.indices_len[4] - return self._long_lora_indices[:long_lora_len] - def _shrink_prefill( self, y: torch.Tensor, @@ -418,13 +122,15 @@ def _expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) - def _apply_expand(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + ): """ Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` computation, which is suitable for the @@ -436,45 +142,8 @@ def _apply_expand(self, self._expand_slice_decode) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - - def _apply_shrink( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): """ Perform the ` 
y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. @@ -490,13 +159,9 @@ def _apply_shrink( shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink( - self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, - ): + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the @@ -521,16 +186,15 @@ def add_shrink( self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) - def add_expand( - self, - y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], - offset_start: int = 0, - add_input=True, - ) -> None: + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -568,13 +232,12 @@ def add_expand( offset_left += output_slices[slice_idx] y = y.view_as(y_org) - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_input: bool = True, - ): + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. 
@@ -593,17 +256,17 @@ def add_lora_embedding( if self.is_prefill else self._expand_decode) expand_fun(y, x, lora_b_stacked, add_input) - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: """ Applicable to linear-related lora. @@ -641,13 +304,14 @@ def add_lora_linear( torch.zeros( (x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand(y, buffer, lora_b_stacked, None, output_slices, - add_input=True) + add_input=True, + **kwargs) def add_lora_logits(self, y: torch.Tensor, @@ -656,7 +320,8 @@ def add_lora_logits(self, lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None) -> None: + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
From bff8f2e815cd98882e0742050570f221f9667fd4 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 6 Dec 2024 06:13:16 +0000 Subject: [PATCH 5/8] Clean up code Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 2 +- vllm/lora/layers.py | 2 +- vllm/lora/models.py | 2 +- vllm/lora/punica_selector.py | 11 -- vllm/lora/punica_wrapper/__init__.py | 10 ++ vllm/lora/{ => punica_wrapper}/punica_base.py | 2 +- vllm/lora/{ => punica_wrapper}/punica_gpu.py | 4 +- vllm/lora/punica_wrapper/punica_selector.py | 14 ++ vllm/lora/punica_wrapper/utils.py | 161 ++++++++++++++++++ vllm/lora/utils.py | 159 +---------------- 10 files changed, 193 insertions(+), 174 deletions(-) delete mode 100644 vllm/lora/punica_selector.py create mode 100644 vllm/lora/punica_wrapper/__init__.py rename vllm/lora/{ => punica_wrapper}/punica_base.py (99%) rename vllm/lora/{ => punica_wrapper}/punica_gpu.py (99%) create mode 100644 vllm/lora/punica_wrapper/punica_selector.py create mode 100644 vllm/lora/punica_wrapper/utils.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index d032be2bbea27..6ee7677ed2be0 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, PackedLoRALayerWeights) -from vllm.lora.punica_gpu import PunicaWrapperGPU +from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 267702856da57..2c25bf3313a96 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -32,7 +32,7 @@ VocabParallelEmbedding) if TYPE_CHECKING: - from vllm.lora.punica_base import PunicaWrapperBase + from vllm.lora.punica_wrapper import PunicaWrapperBase def _get_lora_device(base_layer: nn.Module) -> torch.device: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 
4bbeb41b5ab72..49cd9f0c236ad 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -21,7 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.punica_selector import get_punica_wrapper +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) diff --git a/vllm/lora/punica_selector.py b/vllm/lora/punica_selector.py deleted file mode 100644 index 76cbf54b98fff..0000000000000 --- a/vllm/lora/punica_selector.py +++ /dev/null @@ -1,11 +0,0 @@ -from vllm.lora.punica_base import PunicaWrapperBase -from vllm.platforms import current_platform - - -def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: - if current_platform.is_cuda() or current_platform.is_rocm(): - from vllm.lora.punica_gpu import PunicaWrapperGPU - - return PunicaWrapperGPU(*args, **kwargs) - else: - raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/__init__.py b/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000000000..8e0ffc7155305 --- /dev/null +++ b/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,10 @@ +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + + + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/vllm/lora/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py similarity index 99% rename from vllm/lora/punica_base.py rename to vllm/lora/punica_wrapper/punica_base.py index 0833b6cab8bf6..c85e897597f8d 100644 --- a/vllm/lora/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -10,7 +10,7 @@ import torch -from vllm.lora.utils import compute_meta, convert_mapping +from .utils import compute_meta, convert_mapping if TYPE_CHECKING: # avoid circuit import diff --git a/vllm/lora/punica_gpu.py 
b/vllm/lora/punica_wrapper/punica_gpu.py similarity index 99% rename from vllm/lora/punica_gpu.py rename to vllm/lora/punica_wrapper/punica_gpu.py index de4368021c82e..e51368832337b 100644 --- a/vllm/lora/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -9,7 +9,7 @@ import torch -from vllm.lora.punica_base import PunicaWrapperBase + from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -20,6 +20,8 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from .punica_base import PunicaWrapperBase + @final class PunicaWrapperGPU(PunicaWrapperBase): diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000000000..03b84e12697fb --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,14 @@ + +from vllm.platforms import current_platform +from .punica_base import PunicaWrapperBase +from functools import lru_cache + + +@lru_cache(maxsize=None) +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + if current_platform.is_cuda_alike(): + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + + return PunicaWrapperGPU(*args, **kwargs) + else: + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py new file mode 100644 index 0000000000000..70ee4342cf40a --- /dev/null +++ b/vllm/lora/punica_wrapper/utils.py @@ -0,0 +1,161 @@ +import torch +from typing import List,Optional,Tuple,TYPE_CHECKING,Union + + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. 
If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. 
+ + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. 
optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. 
+ indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index a9492e5fb2fae..5876494ce2824 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,9 +1,8 @@ import os import re -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type, Union +from typing import List, Optional, Set, Tuple, Type, Union import huggingface_hub -import torch from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, RepositoryNotFoundError) from torch import nn @@ -32,11 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - logger = init_logger(__name__) _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { @@ -196,154 +190,3 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path return local_snapshot_path - - -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: - """ - Get the information required for the sgmv kernel. With the features: - 1. If consecutive requests in the batch use the same LoRA, this function - will combine them into a single request, improving sgmv kernel inference - performance. - 2. At the beginning of each prefill stage inference, recalculations are - needed based on the input, but only once. 
- """ - - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) - cum_result = torch.cumsum(seq_length_tensor, dim=0) - b_seq_start_tensor = torch.zeros_like(seq_length_tensor) - b_seq_start_tensor[1:].copy_(cum_result[:-1]) - max_length = seq_length_tensor.max().item() - token_nums = seq_length_tensor.sum().item() - batch_size = lora_indices_tensor.size(0) - no_lora = False - # -1 means no lora should be applied. Use `no_lora` to determine whether - # the current step requires LoRA. If LoRA is not needed, the prefill stage - # does not need to launch the triton kernel, which can improve performance - if batch_size == 1 and lora_indices_tensor == -1: - no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) - - -# TODO see if this can be vectorized -def convert_mapping( - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. 
- sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. It contains - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. 
optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, - lora_indices, - embedding_indices, - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. 
- indices_len = [ - base_indices.shape[-1], - sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) - - return ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_indices, - indices_len, - ) From c4d60e3ec2dbdc6e4954b654d19000df4b3b9bff Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 6 Dec 2024 06:47:49 +0000 Subject: [PATCH 6/8] Fix test code Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 16 ++++++++-------- vllm/lora/punica_wrapper/__init__.py | 3 --- vllm/lora/punica_wrapper/punica_gpu.py | 1 - vllm/lora/punica_wrapper/punica_selector.py | 3 ++- vllm/lora/punica_wrapper/utils.py | 6 ++---- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6ee7677ed2be0..e901f68852091 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, PackedLoRALayerWeights) -from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -205,7 +205,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -305,7 +305,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = 
PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -441,7 +441,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -571,7 +571,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -683,7 +683,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -805,7 +805,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -963,7 +963,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seed = 0 current_platform.seed_everything(seed) torch.set_default_device(device) - punica_wrapper = PunicaWrapperGPU(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/vllm/lora/punica_wrapper/__init__.py 
b/vllm/lora/punica_wrapper/__init__.py index 8e0ffc7155305..48ada3926ea46 100644 --- a/vllm/lora/punica_wrapper/__init__.py +++ b/vllm/lora/punica_wrapper/__init__.py @@ -1,8 +1,5 @@ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase - from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper - - __all__ = [ "PunicaWrapperBase", diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index e51368832337b..b2af29de129ce 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -9,7 +9,6 @@ import torch - from vllm.triton_utils import HAS_TRITON if HAS_TRITON: diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index 03b84e12697fb..68ab3a6353595 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,7 +1,8 @@ +from functools import lru_cache from vllm.platforms import current_platform + from .punica_base import PunicaWrapperBase -from functools import lru_cache @lru_cache(maxsize=None) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 70ee4342cf40a..7360c8c09e3ac 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -1,6 +1,6 @@ -import torch -from typing import List,Optional,Tuple,TYPE_CHECKING,Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import torch if TYPE_CHECKING: # avoid circuit import @@ -8,7 +8,6 @@ from vllm.lora.models import LongContextLoRAContext - def compute_meta( token_lora_tensor: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: @@ -39,7 +38,6 @@ def compute_meta( batch_size, max_length, token_nums, no_lora) - # TODO see if this can be vectorized def convert_mapping( mapping: "LoRAMapping", From ef057fdaf556a4165e663f626c0148ee610ef88a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 6 Dec 2024 
07:38:01 +0000 Subject: [PATCH 7/8] Optimize logic Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_base.py | 2 +- vllm/lora/punica_wrapper/punica_selector.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index c85e897597f8d..928cf1e6415e3 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -351,7 +351,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], scale (float): Scaling factor for the operation """ - + # TODO: implement it based on torch ops raise NotImplementedError @abstractmethod diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index 68ab3a6353595..df6c1bdc7dd71 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,15 +1,14 @@ -from functools import lru_cache - from vllm.platforms import current_platform +from vllm.utils import print_info_once from .punica_base import PunicaWrapperBase -@lru_cache(maxsize=None) def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: if current_platform.is_cuda_alike(): + # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU - + print_info_once("Using PunicaWrapperGPU.") return PunicaWrapperGPU(*args, **kwargs) else: raise NotImplementedError From 0d01cb9f5cb35667d96a3b87174aa585433c5974 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 6 Dec 2024 15:45:59 +0000 Subject: [PATCH 8/8] Add unit test Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 33 ++++++++++++++++----- vllm/lora/punica_wrapper/punica_base.py | 39 ++++++++++++++----------- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index e901f68852091..fb8c0b2a7ba26 100644 --- a/tests/lora/test_layers.py +++ 
b/tests/lora/test_layers.py @@ -48,11 +48,12 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -CUDA_DEVICES = [ +# TODO: Modify this based on platform +DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -# We will launch different triton kernels between the prefill and decode +#For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -192,9 +193,18 @@ def create_random_inputs( return inputs, index_mapping, prompt_mapping +def check_punica_wrapper(punica_wrapper) -> bool: + if current_platform.is_cuda_alike(): + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + + return type(punica_wrapper) is PunicaWrapperGPU + else: + return False + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @@ -206,6 +216,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -296,7 +307,7 @@ def create_random_embedding_layer(): # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, 
device, @@ -306,6 +317,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -432,7 +444,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @@ -442,6 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -563,7 +576,7 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_replicated(dist_init, num_loras, device, stage, @@ -572,6 +585,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -675,7 +689,7 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, 
False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @@ -684,6 +698,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -797,7 +812,7 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, @@ -806,6 +821,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -964,6 +980,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, current_platform.seed_everything(seed) torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 928cf1e6415e3..0a5a84bdd8deb 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py 
@@ -32,7 +32,8 @@ def update_metadata( vocab_size: int, extra_vocab_size: int, long_lora_context: Optional["LongContextLoRAContext"] = None, - ): + **kwargs, + ) -> None: """ Update the lora-related metadata """ @@ -45,7 +46,8 @@ def add_shrink( x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, - ): + **kwargs, + ) -> None: """ Performs GEMM for multiple slices of lora_a. """ @@ -62,6 +64,7 @@ def add_expand( output_slices: Tuple[int, ...], offset_start: int = 0, add_input=True, + **kwargs, ) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -75,7 +78,8 @@ def add_lora_embedding( x: torch.Tensor, lora_b_stacked: torch.Tensor, add_input: bool = True, - ): + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. @@ -83,17 +87,17 @@ def add_lora_embedding( raise NotImplementedError @abstractmethod - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: """ Applicable to linear-related lora. """ @@ -108,7 +112,8 @@ def add_lora_logits(self, lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None) -> None: + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
""" @@ -336,7 +341,7 @@ def update_metadata( @abstractmethod def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs): + scale: float, **kwargs) -> None: """ Performs GEMM for multiple slices of lora_a. @@ -393,7 +398,7 @@ def add_lora_embedding(self, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_input: bool = True, - **kwargs): + **kwargs) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation.