LoRA Implementation for Parameter-Efficient Fine-Tuning on new datasets #159
Open
sidhantls wants to merge 17 commits into huggingface:main from sidhantls:peft (base: main)
Changes from 13 commits
Commits (17):
All commits authored by sidhantls:

2b74d39: utils for peft
4206a99: add basic peft
7f1a762: fix peft by init 0
c85c0da: fix peft, init adapters zero
6d381d4: add peft training arg
e6fc753: integrate use_peft commandline arg
3b4fc62: improve lora initialization
11ea917: add lora args
5b6ec15: fix use nn.init.kaiming_uniform_
f9dd3d6: pass lora params from training args
7b21c2d: set gradient during lora replacment, do lora in one line
3cf84ea: dont set grads with separate fn rather use the grads set during lora …
714c23f: remove unused functions
2dc4875: add fn to convert lora to linear
81b2719: remove unused import
74f0c7e: add relu for training stability
07ca7c1: Merge branch 'peft' of https://github.com/sidhantls/parler-tts-fork i…
New file (84 lines added):
import torch.nn as nn
import torch
from tqdm import tqdm
import math


class LoRALinear(nn.Module):
    def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout):
        super().__init__()
        self.linear = linear_layer

        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout)

        # low-rank adapter pair: A projects down to rank r, B projects back up
        self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False)
        self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False)

        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))  # following microsoft/LoRA
        nn.init.zeros_(self.lora_B.weight)  # zero init: the adapter contributes nothing at the start

        self.scaling = self.lora_alpha / self.lora_r
        self.linear.requires_grad_(False)  # freeze the wrapped base layer

    def forward(self, x):
        # note: the relu around the adapter branch deviates from standard LoRA;
        # it was added in this PR "for training stability" (commit 74f0c7e)
        out = self.linear(x) + torch.relu(self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling)
        return out

def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout):
    # superseded: setattr with a dotted name from named_modules() only rebinds
    # an attribute on `model` itself, so nested linear layers were never replaced
    for name, module in model.named_modules():
        if any(item in name for item in ['embed_prompts', 'lm_heads']):
            print('Ignored adding peft to ', name)
            continue

        if isinstance(module, nn.Linear):
            lora_linear = LoRALinear(module, lora_r, lora_alpha, lora_dropout)
            setattr(model, name, lora_linear)
    return model


def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
    # map each module object to its fully qualified name
    full_name_dict = {module: name for name, module in model.named_modules()}

    # depth-first walk collecting every nn.Linear together with its direct
    # parent, so the replacement can rebind the attribute on the parent module
    linear_info = {}
    modules = [model]
    while len(modules) > 0:
        submodule = modules.pop()
        for name, raw_linear in submodule.named_children():
            if isinstance(raw_linear, torch.nn.Linear):
                full_name = full_name_dict[raw_linear]
                linear_info[raw_linear] = {
                    "father": submodule,
                    "name": name,
                    "full_name": full_name,
                }
            else:
                modules.append(raw_linear)

    total_len = sum(1 for _ in model.named_modules())

    for name, module in tqdm(model.named_modules(), total=total_len, desc='Replacing Linear with Low-Rank Layers', mininterval=5):
        if any(item in name for item in ['embed_prompts', 'lm_heads']):
            print('Ignored adding peft to ', name)

        elif module in linear_info:
            info = linear_info[module]
            new_module = LoRALinear(module, lora_r, lora_alpha, lora_dropout)
            setattr(info["father"], info["name"], new_module)

            del linear_info[module]
            torch.cuda.empty_cache()

    torch.cuda.empty_cache()
    print('Replaced linear layers with low-rank layers.')
    return model


def set_non_lora_gradients_to_false(model):
    # freeze everything except the LoRA adapters; lm_heads and embed_prompts
    # stay trainable because they are excluded from LoRA wrapping above
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False

        if 'lm_heads' in name or 'embed_prompts' in name:
            param.requires_grad = True
            print("Using gradients for lm_heads or embed_prompts", name)
    return model
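As a sanity check of the pieces above: because lora_B is zero-initialized and relu(0) = 0, a wrapped layer reproduces its base layer exactly at initialization, and after replacement only the adapter parameters remain trainable. The sketch below uses a condensed copy of the LoRALinear class and a simplified recursive replacement helper (`replace_linears` is illustrative, not the PR's exact parent-tracking traversal):

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # condensed copy of the LoRALinear class in this diff
    def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout):
        super().__init__()
        self.linear = linear_layer
        self.lora_dropout = nn.Dropout(p=lora_dropout)
        self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False)
        self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)  # adapter contributes nothing at init
        self.scaling = lora_alpha / lora_r
        self.linear.requires_grad_(False)   # freeze the wrapped base layer

    def forward(self, x):
        return self.linear(x) + torch.relu(
            self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling)

def replace_linears(module, lora_r, lora_alpha, lora_dropout):
    # simplified recursive stand-in for replace_linear_with_lora: rebind each
    # nn.Linear on its direct parent so nested layers are actually replaced
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinear(child, lora_r, lora_alpha, lora_dropout))
        else:
            replace_linears(child, lora_r, lora_alpha, lora_dropout)
    return module

# 1) a wrapped layer matches its base layer at init (adapter branch is zero)
base = nn.Linear(8, 4)
wrapped = LoRALinear(base, lora_r=2, lora_alpha=4, lora_dropout=0.0)
x = torch.randn(3, 8)
same_at_init = torch.allclose(wrapped(x), base(x))

# 2) after replacement, only lora_* parameters remain trainable
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
toy = replace_linears(toy, lora_r=4, lora_alpha=8, lora_dropout=0.0)
trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
out = toy(torch.randn(2, 16))
```

With rank 4 on the toy model, only the four small adapter matrices carry gradients, which is the whole point of the parameter-efficient setup.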
Review comment: to do: remove this line
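Commit 2dc4875 mentions a function that converts LoRA layers back to plain Linear, but it is not visible in this 13-commit view. Note that with the relu in this PR's forward pass the adapter branch is nonlinear and cannot be folded exactly into the base weight; for the standard, relu-free LoRA formulation the merge is W' = W + scaling * (B @ A). A hypothetical sketch of that standard merge (the function name and signature are assumptions, not the PR's actual code):

```python
import torch
import torch.nn as nn

def merge_lora_weights(linear, lora_A, lora_B, scaling):
    # hypothetical merge for the standard, relu-free LoRA formulation:
    # W' = W + scaling * (B @ A). With the relu in this PR's forward pass,
    # the adapter branch is nonlinear and cannot be folded exactly like this.
    merged = nn.Linear(linear.in_features, linear.out_features,
                       bias=linear.bias is not None)
    with torch.no_grad():
        merged.weight.copy_(linear.weight + scaling * (lora_B.weight @ lora_A.weight))
        if linear.bias is not None:
            merged.bias.copy_(linear.bias)
    return merged

base = nn.Linear(8, 4)
A = nn.Linear(8, 2, bias=False)   # down-projection, rank 2
B = nn.Linear(2, 4, bias=False)   # up-projection
merged = merge_lora_weights(base, A, B, scaling=2.0)

x = torch.randn(3, 8)
# merged output equals base + scaled adapter branch (no dropout, no relu)
expected = base(x) + 2.0 * B(A(x))
close = torch.allclose(merged(x), expected, atol=1e-6)
```

After a merge like this, inference pays no extra cost for the adapters, which is why standard LoRA avoids nonlinearities on the adapter branch.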