Commit
Done: all the LoRA implementation
yunss-ML committed Aug 26, 2023
1 parent ad9e851 commit c4a717b
Showing 4 changed files with 256 additions and 157 deletions.
54 changes: 34 additions & 20 deletions README.md
@@ -10,36 +10,50 @@
</div>


## Features

**Distinguishing the "loralib" and "loratorch" approaches**

The "loralib" and "loratorch" implementations differ in how they combine the pre-trained weight with the LoRA update; the difference is easiest to see with `nn.Linear`:

- **LoRALib Approach**: computes `xW_0^T` and `x(BA)^T` separately and then sums the results. This gives an exact computation of the LoRA-enhanced layer and is particularly suited to linear layers.

- **LoRATorch Approach**: merges the pre-trained weight `W_0` with its LoRA weight `BA` into the combined matrix $W_0 + \frac{\alpha}{r} BA$ before the forward pass, which makes it straightforward to extend LoRA to more complex, non-linear layers in the PyTorch ecosystem.

## Mathematical Formulation

1. **LoRALib Approach**:

   The computation is defined as:

   $h = x W_0^\top + \frac{\alpha}{r} x (BA)^\top$

   where:
   - $x$ is the input matrix of dimensions $k \times n$,
   - $W_0$ is a pre-trained weight matrix of dimensions $m \times n$,
   - $r$ is a predefined LoRA rank,
   - $B$ and $A$ are LoRA matrices of dimensions $m \times r$ and $r \times n$ respectively,
   - $\alpha$ is a hyper-parameter.

2. **LoRATorch Approach**:

   The computation is defined as:

   $h = x \left( W_0 + \frac{\alpha}{r} BA \right)^\top$

   where $x$, $W_0$, $r$, $B$, $A$, and $\alpha$ are defined as above.

``loralib`` computes $xW_0^\top$ and $x(BA)^\top$ separately and then sums the results, while ``loratorch`` merges the pre-trained weight $W_0$ with its LoRA weight $BA$ and then computes the result with the ordinary ``nn.Linear.forward()``. For linear layers the two approaches are equivalent, but for some non-linear or complex layers it is not guaranteed that $L(x, W_0) + L(x, BA) = L(x, W_0 + BA)$, so extending LoRA to such layers with ``loralib`` is difficult. Merging the weights first, as ``loratorch`` does, is more general and extensible: call ``merge_lora_param()`` in ``loratorch`` to merge the weights, then call ``forward()`` on the original layer to compute the result. With the help of ``loratorch``, you can apply LoRA to any type of ``torch.nn`` layer, as the sketch below illustrates for the linear case.
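
A quick numerical check of this equivalence for the linear case, written directly against ``torch`` tensors rather than either library's API (the shapes and values below are arbitrary):

```python
import torch

k, n, m, r, alpha = 2, 8, 4, 2, 1.0
x  = torch.randn(k, n)   # input, shape (k, n)
W0 = torch.randn(m, n)   # pre-trained weight, shape (m, n)
B  = torch.randn(m, r)   # LoRA matrix B, shape (m, r)
A  = torch.randn(r, n)   # LoRA matrix A, shape (r, n)

# loralib-style: compute x W_0^T and x (BA)^T separately, then sum them
h_separate = x @ W0.T + (alpha / r) * (x @ (B @ A).T)

# loratorch-style: merge the weights first, then apply a single linear map
h_merged = x @ (W0 + (alpha / r) * (B @ A)).T

print(torch.allclose(h_separate, h_merged, atol=1e-5))  # True for a linear layer
```
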
## Usage

1. **AdapterLoRa Class**: The `AdapterLoRa` class provides a versatile interface for applying LoRA adaptation to neural networks. It supports both `loralib` and `loratorch` approaches, offering the ability to reconstruct and implement LoRA-adapted models.

2. **Adapting Layers**: The `add_layer_and_Instance_Layer` method allows you to specify the layers you want to adapt using the `layertyep` and `layer` parameters. This method helps tailor the LoRA application to specific layers in your model.

3. **Freezing Weights**: The `freeze_weights` method enables the option to freeze model weights, enhancing stability and allowing for safer adaptations.

4. **Reconstructing and Implementing LoRA**: The `reconstruct_model` method applies LoRA adaptation to the model, while the `implement_lora` method finalizes the LoRA setup and manages the trainable parameters. The sketch below shows how these calls fit together.
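
A hypothetical end-to-end sketch of how these pieces fit together. The constructor arguments, the base model, and the assumption that `reconstruct_model` and `implement_lora` return the adapted model are illustrative guesses, not the documented API:

```python
import torch.nn as nn
from core.Quantized import AdapterLoRa  # import path assumed from this repository's layout

base_model = nn.TransformerEncoderLayer(d_model=512, nhead=8)

# Assumed constructor arguments: adaptation method and LoRA rank.
adapter = AdapterLoRa(base_model, method="LoRa", Rank=4)

# Register the layer types/instances to adapt (parameter names taken from this README).
adapter.add_layer_and_Instance_Layer(layertyep="nn.Linear", layer="self_attn")

# Freeze the pre-trained weights so only the LoRA parameters stay trainable.
adapter.freeze_weights()

# Apply the LoRA adaptation, then finalize it and manage the trainable parameters.
model = adapter.reconstruct_model()
model = adapter.implement_lora()
```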

## Supported Layers

161 changes: 103 additions & 58 deletions core/LayersAdaptes.py
@@ -1,90 +1,135 @@
import torch.nn as nn
import bitsandbytes as bnb
import loralib as LoRa
import loratorch as LoRaT
from typing import Optional

from .Quantized import AdapterLoRa

LAYERS = AdapterLoRa.layertyep

def Layer(model, new_layer):
"""
Copy weights and biases from the original layer to the new layer.
Args:
model (nn.Module): The original layer.
new_layer (nn.Module): The new layer.
Returns:
nn.Module: The new layer with copied weights and biases.
"""
new_layer.weight = nn.Parameter(model.weight.detach().clone())

if model.bias is not None:
new_layer.bias = nn.Parameter(model.bias.detach().clone())

return new_layer
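
# Illustrative usage of Layer (placeholder sizes; assumes `nn` is torch.nn, as imported above):
#   src = nn.Linear(16, 8)
#   dst = nn.Linear(16, 8)
#   dst = Layer(src, dst)  # dst now holds copies of src's weight and bias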

def LoRaLinear(method: str, model: nn.Module, Rank: Optional[int], threshold: Optional[int]):
"""
Replace a linear layer with a quantized layer using specified method.
Args:
method (str): The quantization method ("LoRa", "SandBytes", "LoRaTorch").
model (nn.Module): The input model containing the linear layer.
Rank (Optional[int]): The rank parameter for LoRA adaptation.
threshold (Optional[int]): The threshold parameter for SandBytes adaptation.
Returns:
nn.Module: The modified model with the quantized layer.
"""
Adapters = ["LoRa", "SandBytes", "LoRaTorch"]

if method in Adapters:
if method == "LoRa":
            new_layer = LoRa.Linear(
                in_features=model.in_features,
                out_features=model.out_features,
                bias=model.bias is not None,
                r=Rank
            )
            return Layer(model, new_layer)

if method == "SandBytes":
            new_layer = bnb.nn.Linear8bitLt(
                model.in_features,
                model.out_features,
                bias=model.bias is not None,
                has_fp16_weights=False,
                threshold=threshold
            )
            return Layer(model, new_layer)

if method == "LoRaTorch":
if method == "LoRaTorch":
new_layer = LoRaT.Linear(
in_features=model.in_features,
out_features=model.out_features,
bias=model.bias is not None,
r=Rank
)
return Layer(model . new_layer)
in_features=model.in_features,
out_features=model.out_features,
bias=model.bias is not None,
r=Rank
)
return Layer(model, new_layer)

    else:
        raise ValueError(f"Unsupported method or invalid method name: {method}")

def LoRaEmbedding(
method: str,
model: nn.Module,
Rank: Optional[int],
lora_alpha: Optional[int],
scale_grad_by_freq: Optional[int],
padding_idx: Optional[int],
max_norm: Optional[int]
):
"""
Replace an embedding layer with a quantized layer using specified method.
Args:
method (str): The quantization method ("LoRa", "SandBytes", "LoRaTorch").
model (nn.Module): The input model containing the embedding layer.
Rank (Optional[int]): The rank parameter for LoRA adaptation.
lora_alpha (Optional[int]): The alpha parameter for LoRA adaptation.
scale_grad_by_freq (Optional[int]): The scale_grad_by_freq parameter for LoRA adaptation.
padding_idx (Optional[int]): The padding_idx parameter for LoRA adaptation.
max_norm (Optional[int]): The max_norm parameter for LoRA adaptation.
Returns:
nn.Module: The modified model with the quantized layer.
"""
Adapters = ["LoRa", "SandBytes", "LoRaTorch"]

if method in Adapters:
if method == "LoRa":
            new_layer = LoRa.Embedding(
                model.num_embeddings,
                model.embedding_dim,
                r=Rank,
                lora_alpha=lora_alpha,
                max_norm=model.max_norm,
                scale_grad_by_freq=model.scale_grad_by_freq,
                padding_idx=model.padding_idx
            )
return new_layer

if method == "SandBytes":
            new_layer = bnb.nn.StableEmbedding(
                model.num_embeddings,
                model.embedding_dim
            )
return new_layer

if method == "LoRaTorch":
            new_layer = LoRaT.Embedding(
                model.num_embeddings,
                model.embedding_dim,
                r=Rank,
                max_norm=model.max_norm,
                scale_grad_by_freq=model.scale_grad_by_freq,
                padding_idx=model.padding_idx
            )
return new_layer
    else:
        raise ValueError(f"Unsupported method or invalid method name: {method}")