"""
---
title: Low-Rank Adaptation (LoRA)
summary: >
Annotated implementation of RoRA from paper
LoRA: Low-Rank Adaptation of Large Language Models
---
# Low-Rank Adaptation (LoRA)
This is an implementation of
[Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685)
in [PyTorch](https://pytorch.org).
Low-Rank Adaptation (LoRA) freezes pre-trained model weights and injects
trainable rank decomposition matrices into each layer of the transformer.
This makes it possible to efficiently fine-tune large langauge models by
reducing trainable parameters by a large factor.
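
As a rough illustration, adapting a single $768 \times 768$ weight matrix with a
rank-$4$ decomposition trains only $4 \times (768 + 768) = 6{,}144$ parameters,
compared to $589{,}824$ for fine-tuning the full matrix.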

Here's [the training code](experiment.html) for training a GPT2 model with LoRA
on the Tiny Shakespeare dataset.
"""
import torch
import torch.nn as nn
class Linear(nn.Module):
"""
## LoRA Linear Layer
    The LoRA linear layer adds a low-rank decomposition to the pre-trained
    weight matrix ($W_0 \in \mathbb{R}^{d \times k}$) of the linear layer.

    $$W_0 + \Delta W = W_0 + BA$$

    where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$,
    and the rank $r \ll \min(d, k)$.

    All parameters are frozen except $A$ and $B$.

    $\Delta W$ is initialized to zero at the beginning of training.

    The paper scales $x \Delta W^T$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter.
    Once $\alpha$ is tuned, it can be kept the same when varying $r$.
"""
def __init__(self, in_features: int, out_features: int, bias: bool,
r: int, alpha: int = None):
"""
:param in_features: is the number of input features of the linear layer
:param out_features: is the number of output features of the linear layer
:param bias: is a flag indicating if there is a bias parameter
:param r: is the rank of the decomposition $r$
:param alpha: is the scaling factor $\alpha$
"""
super().__init__()
        # Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
alpha = r
# The pre-trained weight $W_0$
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
# Freeze it
self.weight.requires_grad = False
if bias:
# Bias parameter $b_0$ (also frozen)
self.bias = nn.Parameter(torch.empty(out_features))
self.bias.requires_grad = False
else:
# No bias parameter
self.bias = None
        # Scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
# Matrix $A \in \mathbb{R}^{r \times k}$
self.lora_a = nn.Parameter(torch.empty((r, in_features)))
        # Matrix $B \in \mathbb{R}^{d \times r}$
self.lora_b = nn.Parameter(torch.empty((out_features, r)))
with torch.no_grad():
            # Initialize $A$ the same way a weight matrix is initialized in a standard linear layer
nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
# Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization
nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
# Compute $x W_0^T + b_0$
result = nn.functional.linear(x, self.weight, bias=self.bias)
# Add $\frac{\alpha}{r} x \Delta W^T = \frac{\alpha}{r} x {(BA)}^T = \frac{\alpha}{r} x A^T B^T$
result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
#
return result
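# Side note: since the learned update is just the matrix $\frac{\alpha}{r} BA$,
# it can be merged into the frozen weight after training, so inference runs as a
# plain linear layer with no extra latency. A minimal sketch, assuming `layer`
# is a trained instance of the `Linear` class above:
#
#     with torch.no_grad():
#         layer.weight += layer.scaling * (layer.lora_b @ layer.lora_a)
#         nn.init.zeros_(layer.lora_b)  # the LoRA path now contributes nothing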
class Embedding(nn.Module):
"""
## LoRA Embedding Layer
    Similar to the LoRA linear layer, this adds a low-rank decomposition to the
    pre-trained embedding weight matrix ($W_0 \in \mathbb{R}^{d \times k}$).

    $$W_0 + \Delta W = W_0 + BA$$
"""
def __init__(self, num_embeddings: int, embedding_dim: int,
r: int, alpha: int = None):
"""
:param num_embeddings: is the number of embeddings
        :param embedding_dim: is the number of embedding dimensions
:param r: is the rank of the decomposition $r$
:param alpha: is the scaling factor $\alpha$
"""
super().__init__()
        # Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
alpha = r
# The pre-trained embedding weights $W_0^T$ (frozen)
self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False
        # Scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
# Matrix $A \in \mathbb{R}^{r \times k}$
self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
# Matrix $B \in \mathbb{R}^{d \times r}$
self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))
with torch.no_grad():
# Initialize $A$ with a normal distribution
nn.init.normal_(self.lora_a)
# Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization
nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
        # Compute the embeddings $\text{onehot}(x) W_0^T$
result = nn.functional.embedding(x, self.weight)
# Add $\frac{\alpha}{r} \text{onehot}(x) \Delta W^T = \frac{\alpha}{r} \text{onehot}(x) A^T B^T$
result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling
#
return result
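# A minimal usage sketch and smoke test (an assumed workflow, not prescribed by
# the code above): copy pre-trained weights into the frozen parameters, then
# check that at initialization the LoRA layers reproduce the base layers exactly
# (since $B = 0$ means $\Delta W = BA = 0$) and that only $A$ and $B$ are
# trainable. The sizes and the rank below are arbitrary.
if __name__ == '__main__':
    # LoRA linear layer wrapping a small pre-trained `nn.Linear`
    base_linear = nn.Linear(16, 32)
    lora_linear = Linear(16, 32, bias=True, r=4)
    with torch.no_grad():
        lora_linear.weight.copy_(base_linear.weight)
        lora_linear.bias.copy_(base_linear.bias)
    x = torch.randn(2, 16)
    # $\Delta W = 0$ at initialization, so the outputs match
    assert torch.allclose(lora_linear(x), base_linear(x))
    # Only the low-rank matrices require gradients
    assert sorted(n for n, p in lora_linear.named_parameters() if p.requires_grad) == ['lora_a', 'lora_b']

    # LoRA embedding layer wrapping a small pre-trained `nn.Embedding`
    base_emb = nn.Embedding(100, 32)
    lora_emb = Embedding(100, 32, r=4)
    with torch.no_grad():
        lora_emb.weight.copy_(base_emb.weight)
    ids = torch.randint(0, 100, (2, 8))
    assert torch.allclose(lora_emb(ids), base_emb(ids))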