From 571da8fc431ec36427ee1034a7779b23229b015e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 5 Dec 2024 21:22:28 +0800 Subject: [PATCH] [Misc][LoRA] Clean up the function interface of Punica (#10917) Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 42 ++- vllm/lora/fully_sharded_layers.py | 175 +++++----- vllm/lora/layers.py | 538 +++++++++++------------------- vllm/lora/models.py | 8 +- vllm/lora/punica.py | 365 ++++++++++---------- 5 files changed, 497 insertions(+), 631 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 15e576cb065c7..a113e3f7abc1e 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -565,7 +565,9 @@ def _pretest(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) -def test_linear_replicated(dist_init, num_loras, device, stage) -> None: +@pytest.mark.parametrize("bias_enabled", [True, False]) +def test_linear_replicated(dist_init, num_loras, device, stage, + bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -573,7 +575,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None: max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_random_linear_replicated_layer(): @@ -585,7 +588,12 @@ def create_random_linear_replicated_layer(): lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == 1) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -669,8 +677,9 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) +@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage) -> None: + device, stage, bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_random_linear_parallel_layer(): if orientation == "row": @@ -700,7 +710,12 @@ def create_random_linear_parallel_layer(): if not fully_shard else ColumnParallelLinearWithShardedLoRA(linear)) lora_linear.create_lora_weights(max_loras, lora_config) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == 1) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -784,8 +799,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) +@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: + device, stage, bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_column_parallel_packed_layer(): if repeats == 2: @@ -832,10 +849,16 @@ class FakeConfig: num_key_value_heads = 32 num_attention_heads = 32 + n_slices = repeats lora_linear.create_lora_weights(max_loras, lora_config, model_config=FakeConfig()) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -911,7 +934,6 @@ class FakeConfig: 512, lora_config.lora_extra_vocab_size, ) - # lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index e25e453201f01..545ec21ca74c1 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,5 +1,5 @@ # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast import torch import torch.nn as nn @@ -32,6 +32,44 @@ def dec(*args, **kwargs): return dec +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + buffers = tensor_model_parallel_all_gather(buffers) + layer.punica_wrapper.add_expand(output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + # these layers are based on the tensor parallelism strategy given in # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, # https://arxiv.org/abs/2311.03285. @@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked.shape[2] + shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device, - ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) - buffer = tensor_model_parallel_all_gather(buffer) - self.punica_wrapper.add_expand(output, - buffer, - self.lora_b_stacked, - self.bias_stacked, - add_input=True) - # now have column partitioned output - - output = output.view(*out_orig_shape) - return output + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -99,46 +118,6 @@ def can_replace_layer( ) -def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): - """ - MergedColumnParallelLinearWithShardedLoRA and - MergedQKVParallelLinearWithShardedLora share the same - LoRa weight application method. - - The main difference is the step by shard_size for lora_b which can - vary for MergedQKVParallelLinearWithShardedLora but is constant for - MergedColumnParallelLinearWithShardedLoRA. - """ - # expecting 2 for column parallel and 3 for qkv - n = len(layer.lora_a_stacked) - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffers = torch.zeros( - (n, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - for idx in range(n): - layer.punica_wrapper.add_shrink(buffers[idx], x, - layer.lora_a_stacked[idx], 1.0) - - buffers = tensor_model_parallel_all_gather(buffers) - layer.punica_wrapper.add_expand_packed_nslice( - output, - buffers, - layer.lora_b_stacked, - layer.bias_stacked, - 1.0, - layer.output_slices, - ) - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - class MergedColumnParallelLinearWithShardedLoRA( MergedColumnParallelLinearWithLoRA): """ @@ -162,8 +141,9 @@ def slice_lora_a( ] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked.shape[2] + shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) - buffer = tensor_model_parallel_all_gather(buffer) - self.punica_wrapper.add_expand(output, - buffer, - self.lora_b_stacked, - self.bias_stacked, - add_input=True) - # now have column partitioned output - output = output.view(*out_orig_shape) - return output + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -260,8 +224,9 @@ def slice_lora_a( ] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): """ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked.shape[2] + shard_size = self.lora_b_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size lora_b = lora_b[:, start_idx:end_idx] @@ -303,20 +268,24 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: if bias is None: return bias - shard_size = self.bias_stacked.shape[2] + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size bias = bias[start_idx:end_idx] return bias - def apply(self, x: torch.Tensor) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape buffer = torch.zeros( - (x.shape[0], self.lora_a_stacked.shape[2]), + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), dtype=torch.float32, device=x.device, ) @@ -330,12 +299,18 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # remains is a standard all_reduce. User should be aware though that # the output is not the same as a normal row_parallel, it should be # reduced before being used - shard_size = self.lora_b_stacked.shape[2] - start_idx = self.tp_rank * shard_size - self.punica_wrapper.add_expand_slice(output, buffer, - self.lora_b_stacked, - self.bias_stacked, start_idx, - shard_size) + # NOTE offset are based on the rank. + shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 73748b5ce511e..473e4bedf3d60 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import torch import torch.nn as nn @@ -18,11 +18,14 @@ tensor_model_parallel_gather) from vllm.distributed.utils import divide from vllm.lora.punica import PunicaWrapper +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import ( LinearScalingRotaryEmbedding, RotaryEmbedding) @@ -249,13 +252,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - - # Embedding layer only need expand op - self.punica_wrapper.add_expand(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - bias_all=None, - add_input=True) + self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) return full_output.view_as(full_output_org) @classmethod @@ -269,14 +269,19 @@ def can_replace_layer( return type(source_layer) is VocabParallelEmbedding -class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - def __init__(self, base_layer: ReplicatedLinear) -> None: + def __init__(self, base_layer: LinearBase): super().__init__() self.base_layer = base_layer self.input_size = self.base_layer.input_size - self.output_size = self.base_layer.output_size self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int def create_lora_weights( self, @@ -285,39 +290,64 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: self.lora_config = lora_config - lora_a_output_size = lora_config.max_lora_rank - self.lora_a_stacked = torch.zeros( - max_loras, - 1, - lora_a_output_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( max_loras, 1, - self.output_size, + lora_a_out_size, + self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) - else: - self.bias_stacked = None + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 def set_lora( self, @@ -325,29 +355,56 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + lora_bias: Optional[torch.Tensor] = None, ): - self.reset_lora(index) + # Except for QKVParallelLinearWithLora and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, self.bias_stacked, - 1.0) + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) return output + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + def forward(self, input_): """Forward of ReplicatedLinearWithLoRA @@ -380,73 +437,26 @@ def can_replace_layer( return type(source_layer) is ReplicatedLinear -class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. """ def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__() + super().__init__(base_layer) # The base_layer type is ColumnParallelLinear or # MergedColumnParallelLinear, their weight sharding logic is # inconsistent when TP is greater than 1. self.is_merged_col_linear = type( base_layer) is MergedColumnParallelLinear - - self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() - self.input_size = self.base_layer.input_size self.output_size = self.base_layer.output_size_per_partition - self.device = _get_lora_device(self.base_layer) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - self.tp_size = get_tensor_model_parallel_world_size() - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - self.lora_a_stacked = torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) - - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - else: - self.bias_stacked = None - - self.output_dim = self.lora_b_stacked.shape[2] - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + # There is only one LoRA layer + self.n_slices = 1 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: return lora_a @@ -485,40 +495,6 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: bias = bias[start_idx:end_idx] return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, self.bias_stacked, - 1.0) - return output - def forward(self, input_): """Forward of ColumnParallelLinear @@ -568,6 +544,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: MergedColumnParallelLinear) -> None: super().__init__(base_layer) + # There are two LoRA layers + self.n_slices = len(self.base_layer.output_sizes) def create_lora_weights( self, @@ -575,9 +553,13 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ self.lora_config = lora_config - n_slices = 2 - if not (len(self.base_layer.output_sizes) == n_slices + + if not (len(self.base_layer.output_sizes) == self.n_slices == 2 and self.base_layer.output_sizes[0] == self.base_layer.output_sizes[1]): raise ValueError( @@ -598,7 +580,7 @@ def create_lora_weights( self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) + ) for _ in range(self.n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, @@ -607,30 +589,19 @@ def create_lora_weights( lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) + ) for _ in range(self.n_slices)) if lora_config.bias_enabled: - self.bias_stacked = tuple( + self.lora_bias_stacked = tuple( torch.zeros( max_loras, 1, self.output_size // 2, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) - else: - self.bias_stacked = None + ) for _ in range(self.n_slices)) self.output_dim = self.lora_b_stacked[0].shape[2] self.output_slices = (self.output_dim, self.output_dim) - def reset_lora(self, index: int): - self.lora_a_stacked[0][index] = 0 - self.lora_a_stacked[1][index] = 0 - self.lora_b_stacked[0][index] = 0 - self.lora_b_stacked[1][index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[0][index] = 0 - self.bias_stacked[1][index] = 0 - def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: @@ -668,15 +639,15 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + lora_bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -685,10 +656,11 @@ def set_lora( self.lora_b_stacked[0][ index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( lora_b[0].T, non_blocking=True) - if bias is not None and bias[0] is not None: - self.bias_stacked[0][index, - 0, :bias[0].shape[0]].copy_(bias[0].T, - non_blocking=True) + if lora_bias is not None and lora_bias[0] is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_( + lora_bias[0].T, non_blocking=True) if lora_a[1] is not None: self.lora_a_stacked[1][ index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( @@ -696,18 +668,11 @@ def set_lora( self.lora_b_stacked[1][ index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( lora_b[1].T, non_blocking=True) - if bias is not None and bias[1] is not None: - self.bias_stacked[1][index, - 0, :bias[1].shape[0]].copy_(bias[1].T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_packed_nslice( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.bias_stacked, 1.0, (self.output_dim, self.output_dim)) - return output + if lora_bias is not None and lora_bias[1] is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_( + lora_bias[1].T, non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -737,7 +702,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) - self.tp_size = get_tensor_model_parallel_world_size() self.q_proj_total_size = (self.base_layer.total_num_heads * self.base_layer.head_size) self.q_proj_shard_size = (self.base_layer.num_heads * @@ -746,6 +710,8 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: self.base_layer.head_size) self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() @@ -780,32 +746,6 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: bias = torch.cat([bias_q, bias_k, bias_v], dim=1) return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - @classmethod @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, @@ -828,6 +768,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) + # There are three LoRA layer. + self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() def create_lora_weights( self, @@ -835,9 +779,16 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ self.lora_config = lora_config - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + + if not (len(self.base_layer.output_sizes) == self.n_slices == 3): + raise ValueError( + "LoRAColumnParallelLinear3Slice requires 3 slices.") + self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * @@ -902,7 +853,7 @@ def create_lora_weights( ), ) if lora_config.bias_enabled: - self.bias_stacked = ( + self.lora_bias_stacked = ( torch.zeros( max_loras, 1, @@ -925,9 +876,6 @@ def create_lora_weights( device=self.device, ), ) - else: - self.bias_stacked = None - self.output_slices = ( self.q_proj_shard_size, self.kv_proj_shard_size, @@ -939,18 +887,6 @@ def create_lora_weights( self.indices: torch.Tensor self.indices_len: List[int] - def reset_lora(self, index: int): - self.lora_a_stacked[0][index] = 0 - self.lora_b_stacked[0][index] = 0 - self.lora_a_stacked[1][index] = 0 - self.lora_b_stacked[1][index] = 0 - self.lora_a_stacked[2][index] = 0 - self.lora_b_stacked[2][index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[0][index] = 0 - self.bias_stacked[1][index] = 0 - self.bias_stacked[2][index] = 0 - def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: @@ -1000,15 +936,15 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + lora_bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) if lora_b[0] is not None: lora_b_q = lora_b[0] @@ -1039,26 +975,24 @@ def set_lora( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( lora_a[2].T, non_blocking=True) - if bias is not None: - if bias[0] is not None: - self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_( - bias[0].T, non_blocking=True) - if bias[1] is not None: - self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_( - bias[1].T, non_blocking=True) - if bias[2] is not None: - self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_( - bias[2].T, non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_packed_nslice(output, x, - self.lora_a_stacked, - self.lora_b_stacked, - self.bias_stacked, 1.0, - self.output_slices) - return output + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + if lora_bias[0] is not None: + self.lora_bias_stacked[0][index, + 0, :lora_bias[0].shape[0]].copy_( + lora_bias[0].T, + non_blocking=True) + if lora_bias[1] is not None: + self.lora_bias_stacked[1][index, + 0, :lora_bias[1].shape[0]].copy_( + lora_bias[1].T, + non_blocking=True) + if lora_bias[2] is not None: + self.lora_bias_stacked[2][index, + 0, :lora_bias[2].shape[0]].copy_( + lora_bias[2].T, + non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -1073,76 +1007,25 @@ def can_replace_layer( and len(packed_modules_list) == 3) -class RowParallelLinearWithLoRA(BaseLayerWithLoRA): +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__() - self.base_layer = base_layer + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size self.input_size = self.base_layer.input_size_per_partition self.output_size = self.base_layer.output_size - self.device = _get_lora_device(self.base_layer) - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config self.tp_rank = get_tensor_model_parallel_rank() - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.input_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - tp_size = get_tensor_model_parallel_world_size() - lora_b_output_size_per_partition = ( - self.output_size if not lora_config.fully_sharded_loras else - divide(self.output_size, tp_size)) - - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - lora_b_output_size_per_partition, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( - ( - max_loras, - 1, - self.output_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - else: - self.bias_stacked = None - # Lazily initialized - self.indices: torch.Tensor - self.indices_len: List[int] - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + # There is only one LoRA layer. + self.n_slices = 1 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] return lora_a @@ -1152,40 +1035,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.base_layer.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, self.bias_stacked, - 1.0) - return output - def forward(self, input_): """Forward of RowParallelLinear @@ -1203,10 +1052,9 @@ def forward(self, input_): input_parallel = input_ else: # TODO: simplify code below - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. output_parallel = self.apply(input_parallel) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2ffefe61427e3..9855b57d0c9c9 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -555,17 +555,17 @@ def create_dummy_lora( input_dim, output_dim, rank, - module.lora_a_stacked.dtype, + module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, bias_enabled=bias_enabled) else: lora = LoRALayerWeights.create_dummy_lora_weights( module_name, - module.lora_a_stacked.shape[-1], - module.lora_b_stacked.shape[-2], + module.lora_a_stacked[0].shape[-1], + module.lora_b_stacked[0].shape[-2], rank, - module.lora_a_stacked.dtype, + module.lora_a_stacked[0].dtype, "cpu", bias_enabled=bias_enabled, ) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 3f775b7ba363e..563d1181d6fcb 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -362,7 +362,7 @@ def long_lora_indices(self) -> torch.Tensor: long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] - def shrink_prefill( + def _shrink_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -380,7 +380,7 @@ def shrink_prefill( scale, ) - def shrink_decode( + def _shrink_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -389,7 +389,7 @@ def shrink_decode( ): bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def expand_prefill( + def _expand_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -407,7 +407,7 @@ def expand_prefill( add_input, ) - def expand_decode( + def _expand_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -416,7 +416,7 @@ def expand_decode( ): bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - def expand_slice_prefill( + def _expand_slice_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -438,7 +438,7 @@ def expand_slice_prefill( add_input, ) - def expand_slice_decode( + def _expand_slice_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -450,41 +450,35 @@ def expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) - def apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - bias_stacked: torch.Tensor, - ): - """Applies bias to output - - Input shapes: - bias_stacked: (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, output_dim) + def _apply_expand(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) - bias_stacked = bias_stacked[indices] - bias_stacked[indices == -1] = 0 - output += bias_stacked - return output.view_as(org_output) + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - def apply_bias_packed_nslice( + def _apply_bias( self, indices: torch.Tensor, output: torch.Tensor, output_slices: Tuple[int, ...], - bias_stacked: Tuple[Optional[torch.Tensor], ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], ): """Applies bias to output Input shapes: - bias_stacked: 3 element tuple of (num_loras, output_dim) + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) indices: (batch_size) output: (batch_size, q_slice_size + 2*kv_slice_size) output_slices: n-1 element tuple of (slice_size...), @@ -496,7 +490,7 @@ def apply_bias_packed_nslice( offset_left = 0 for slice_idx, slice in enumerate(output_slices): - bias = bias_stacked[slice_idx] + bias = lora_bias_stacked[slice_idx] if bias is not None: bias = bias.view(-1, bias.shape[-1]) bias = bias[indices] @@ -506,7 +500,7 @@ def apply_bias_packed_nslice( return output.view_as(org_output) - def add_shrink( + def _apply_shrink( self, y: torch.Tensor, x: torch.Tensor, @@ -517,188 +511,215 @@ def add_shrink( Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the shrink_decode function + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function should be called. """ - shrink_fun: Callable = (self.shrink_prefill - if self.is_prefill else self.shrink_decode) + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) - def add_expand( + def add_shrink( self, - y: torch.Tensor, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, - w_t_all: torch.Tensor, - bias_all: Optional[torch.Tensor], - add_input: bool = True, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, ): """ - Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the - GEMM of lora'b. - When `is_prefill` is true, it indicates that it is currently the - prefill stage, and the `expand_prefill` function should be called. - Otherwise, it is the decode stage, and the expand_decode function + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function should be called. - """ - if bias_all is not None: - y = self.apply_bias(self.token_lora_indices, y, bias_all) - - expand_fun: Callable = (self.expand_prefill - if self.is_prefill else self.expand_decode) - expand_fun(y, x, w_t_all, add_input) - - def add_expand_slice(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - bias_all: Optional[torch.Tensor], - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): - """ - Similar to `add_expand` - """ - if bias_all is not None: - y = self.apply_bias(self.token_lora_indices, y, bias_all) + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ - expand_slice_fun: Callable = (self.expand_slice_prefill - if self.is_prefill else - self.expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) - def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_b_stacked: Tuple[torch.Tensor, ...], - bias_stacked: Optional[Tuple[torch.Tensor, - ...]], - scale: float, - output_slices: Tuple[int, ...]) -> None: - """ - Similar to `add_expand` + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + ) -> None: """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + """ y_org = y y = y.view(-1, y.shape[-1]) - offset_left = 0 - if bias_stacked is not None: - self.apply_bias_packed_nslice(self.token_lora_indices, y, - output_slices, bias_stacked) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - self.add_expand_slice(y, - x[slice_idx], - lora_b_stacked[slice_idx], - None, - offset_left, - output_slices[slice_idx], - add_input=True) + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) offset_left += output_slices[slice_idx] - y = y.view_as(y_org) - def add_lora(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - bias_all: Optional[torch.Tensor], - scale: float, - y_offset: Optional[int] = None, - y_slice_size: Optional[int] = None, - *, - buffer: Optional[torch.Tensor] = None) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + ): + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_input) + + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: """ + Applicable to linear-related lora. + Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0)+bias[i] + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + Args: - y (torch.Tensor): Output tensor. Will be changed in-place. + y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - wa_t_all (torch.Tensor): lora_a's weight - wb_t_all (torch.Tensor): lora_b's weight - bias_all: (torch.Tensor): lora's bias + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - y_offset (Optional[int], optional): Offset to apply to the starting - column of y. - y_slice_size (Optional[int], optional): Size of the y column slice. - buffer (Optional[torch.Tensor], optional): Defaults to None. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + if buffer is None: + r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - if bias_all is not None: - y = self.apply_bias(self.token_lora_indices, y, bias_all) - self.add_shrink(buffer, x, wa_t_all, scale) - if y_offset is None and y_slice_size is None: - self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True) - else: - self.add_expand_slice(y, - buffer, - wb_t_all, - None, - y_offset, - y_slice_size, - add_input=True) - y = y.view_as(y_org) - - def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - bias_all: Tuple[Optional[torch.Tensor], - ...], scale: float, - output_slices: Tuple[int, ...]) -> None: - """ - Applies lora to each input. Similar to add_lora, This method is - used for layers that are composed of multiple sublayers - (slices) packed together. - """ - y_org = y - x = x.view(-1, x.shape[-1]) - y = y.view(-1, y.shape[-1]) - offset_left = 0 - if bias_all is not None: - y = self.apply_bias_packed_nslice(self.token_lora_indices, y, - output_slices, bias_all) - # TODO fuse these kernels - for slice_idx in range(len(output_slices)): - self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], None, scale, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] - - y = y.view_as(y_org) + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_input=True) def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, scale, *, buffer: Optional[torch.Tensor] = None) -> None: """ - LogitsProcessorWithLoRA always using bgmv - """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) + r = lora_b_stacked.size(-1) if buffer is None: # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - - bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) y = y.view_as(y_org)