feat: support finetuned lm_head and embed_tokens in LoRA adapters #909

Draft: wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions aphrodite/common/config.py
@@ -1576,6 +1576,7 @@ class LoRAConfig:
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
enable_lora_modules_to_save: bool = False

def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
8 changes: 8 additions & 0 deletions aphrodite/engine/args_tools.py
@@ -141,6 +141,7 @@ class EngineArgs:
lora_dtype: str = "auto"
max_cpu_loras: Optional[int] = None
long_lora_scaling_factors: Optional[Tuple[float]] = None
enable_lora_modules_to_save: bool = False
fully_sharded_loras: bool = False
qlora_adapter_name_or_path: Optional[str] = None
enable_prompt_adapter: bool = False
@@ -832,6 +833,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
help="Category: Adapter Options\n"
"Name or path of the LoRA adapter to use.")
parser.add_argument(
"--enable-lora-modules-to-save",
action="store_true",
help="Category: Adapter Options\n"
"If True, fully trained lm_head and embed_tokens "
"in LoRA will be used instead of A*B-style adapters.")
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='Category: Adapter Options\n'
@@ -1058,6 +1065,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
enable_lora_modules_to_save=self.enable_lora_modules_to_save,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

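For illustration, a minimal usage sketch of the new flag when building the engine configuration programmatically. The model name and any arguments other than enable_lora and enable_lora_modules_to_save are assumptions for this sketch, not part of the diff:

from aphrodite.engine.args_tools import EngineArgs

# Hypothetical usage sketch (not part of this diff). The model name is a
# placeholder; enable_lora must be enabled for the new flag to take effect.
engine_args = EngineArgs(
    model="my-base-model",
    enable_lora=True,
    enable_lora_modules_to_save=True,
)
engine_config = engine_args.create_engine_config()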
196 changes: 173 additions & 23 deletions aphrodite/lora/layers.py
@@ -27,7 +27,7 @@
from aphrodite.modeling.layers.rotary_embedding import (
LinearScalingRotaryEmbedding, RotaryEmbedding)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)

if TYPE_CHECKING:
pass
@@ -64,6 +64,25 @@ def dec(*args, **kwargs):
return dec


class TensorPropertiesMixin:

@property
def dtype(self):
return self._dtype

@dtype.setter
def dtype(self, value):
self._dtype = value

@property
def device(self):
return self._device

@device.setter
def device(self, value):
self._device = value


@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
@@ -124,11 +143,13 @@ def can_replace_layer(
raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin):

def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.dtype = self.base_layer.weight.dtype
self.device = self.base_layer.weight.device
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]

@@ -155,25 +176,20 @@ def create_lora_weights(
self.embeddings_slice = None
self.embeddings_weights = None

self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.embeddings_tensors = torch.zeros((
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.dtype,
device=self.device)
self.lora_a_stacked = torch.zeros((
max_loras,
self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=self.dtype,
device=self.device)
self.lora_b_stacked = torch.zeros(
(
max_loras,
@@ -182,7 +198,7 @@ def create_lora_weights(
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
device=self.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
@@ -260,6 +276,9 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# Do not use A*B-style LoRA here; ModulesToSaveWrapper will replace this layer instead.
if lora_config.enable_lora_modules_to_save:
return False
return type(source_layer) is VocabParallelEmbedding


@@ -1010,7 +1029,7 @@ def can_replace_layer(
return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
class LogitsProcessorWithLoRA(BaseLayerWithLoRA, TensorPropertiesMixin):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
@@ -1306,3 +1325,134 @@ def can_replace_layer(

def extra_repr(self) -> str:
return self.base_layer.extra_repr()


class ModulesToSaveWrapper(BaseLayerWithLoRA, TensorPropertiesMixin):
"""
LoRA wrapper for the lm_head and embed_tokens layers, inspired by
peft's ModulesToSaveWrapper. It contains a copy of base_layer with
its weights replaced, and overrides attribute access so that the
attributes of this base_layer copy are returned; clients can
therefore use ModulesToSaveWrapper exactly like the base_layer module.

Args:
base_layer: the layer to be replaced by the wrapper:
VocabParallelEmbedding (for embed_tokens)
or ParallelLMHead (for lm_head)
"""

implemented_layers = ['lm_head', 'embed_tokens']

def __init__(
self, base_layer: Union[VocabParallelEmbedding,
ParallelLMHead]) -> None:
super().__init__()
self.base_layer = base_layer

self.device = _get_lora_device(self.base_layer)

self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()

@property
def padded_vocab_size(self):
# Number of embeddings, including padding and the extra LoRA vocabulary.
return self.base_layer.num_embeddings_padded

@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size

@property
def embedding_dim(self):
return self.base_layer.embedding_dim

@property
def bias(self):
return self.base_layer.bias

@property
def linear_method(self):
if self.punica_wrapper.no_lora:
return self.base_layer.linear_method

return self

@property
def weight(self):
return self.base_layer.weight

def apply(self, lm_head: 'ModulesToSaveWrapper',
hidden_states: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

assert isinstance(self.base_layer, ParallelLMHead)

logits = self.punica_wrapper.bgmv_sample(hidden_states,
self._lora_tensors,
self.base_layer.weight)

if bias is not None:
logits += bias

return logits

def embedding(self, embed_tokens: 'ModulesToSaveWrapper',
masked_input: torch.LongTensor):
assert isinstance(self.base_layer, VocabParallelEmbedding)
embeddings = self.punica_wrapper.bgmv_embedding(
masked_input, self._lora_tensors, self.base_layer.weight)
return embeddings

def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:

self.dtype = lora_config.lora_dtype

# _lora_tensors holds full lm_head weights when the base is ParallelLMHead,
# or full embed_tokens weights when the base is VocabParallelEmbedding.
self._lora_tensors = torch.zeros(
(max_loras, self.padded_vocab_size, self.base_layer.embedding_dim),
dtype=self.base_layer.weight.dtype,
device=self.device,
)
for index in range(max_loras):
self.reset_lora(index)

def reset_lora(self, index: int):
weights = self.base_layer.weight
self._lora_tensors[index, :weights.shape[0], :weights.shape[1]].copy_(
weights, non_blocking=True)

def set_lora(
self,
index: int,
lora_a: Optional[torch.Tensor],
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
assert lora_a is None
assert embeddings_tensor is None

self.reset_lora(index)
self._lora_tensors[index, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)

def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)

@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
if not lora_config.enable_lora_modules_to_save:
return False
return type(source_layer) in (ParallelLMHead, VocabParallelEmbedding)
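To make the intended difference concrete: a standard LoRA adapter stores a low-rank A*B delta for a layer, whereas modules_to_save ships a fully finetuned replacement weight that set_lora above copies verbatim into the per-adapter slot. A minimal sketch, with all names and shapes purely illustrative (and much smaller than real models):

import torch

vocab_size, hidden_size, rank = 320, 64, 4  # illustrative shapes only

# A*B-style LoRA: the effective weight is base + lora_b @ lora_a.
base_weight = torch.zeros(vocab_size, hidden_size)
lora_a = torch.randn(rank, hidden_size)
lora_b = torch.randn(vocab_size, rank)
effective_weight = base_weight + lora_b @ lora_a

# modules_to_save: the adapter provides the fully finetuned weight itself;
# ModulesToSaveWrapper.set_lora copies it directly (lora_a is None).
finetuned_weight = torch.randn(vocab_size, hidden_size)
adapter_slot = torch.empty_like(finetuned_weight)
adapter_slot.copy_(finetuned_weight)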