LoRA Implementation for Parameter-Efficient Fine-Tuning on new datasets #159

Open · wants to merge 17 commits into base: main · Changes from 13 commits
34 changes: 34 additions & 0 deletions training/arguments.py
@@ -360,3 +360,37 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={
            "help": "Whether to use parameter-efficient fine-tuning with LoRA for the decoder transformer. Defaults to full fine-tuning without LoRA."
        },
    )

    lora_r: int = field(
        default=8,
        metadata={
            "help": (
                "The rank of the low-rank adaptation matrices in LoRA."
            )
        },
    )

    lora_alpha: float = field(
        default=16.0,
        metadata={
            "help": (
                "The scaling factor for the LoRA updates; controls the strength of the adaptation."
            )
        },
    )

    lora_dropout: float = field(
        default=0.05,
        metadata={
            "help": (
                "The dropout rate applied to the LoRA layers during training."
            )
        },
    )
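Note (illustrative, not part of this diff): the new fields are picked up through HfArgumentParser like the rest of ParlerTTSTrainingArguments. A minimal sketch, assuming output_dir is the only other required argument and using a placeholder output path:

from transformers import HfArgumentParser
from training.arguments import ParlerTTSTrainingArguments

# Parse the new LoRA options the same way the training script would receive them.
parser = HfArgumentParser(ParlerTTSTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=[
    "--output_dir", "./tmp_lora_run",  # placeholder path
    "--use_peft",                      # bool flag, defaults to False when omitted
    "--lora_r", "8",
    "--lora_alpha", "16.0",
    "--lora_dropout", "0.05",
])
print(training_args.use_peft, training_args.lora_r, training_args.lora_alpha, training_args.lora_dropout)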

84 changes: 84 additions & 0 deletions training/peft_utils.py
@@ -0,0 +1,84 @@
import torch.nn as nn
import torch
from tqdm import tqdm
import math

class LoRALinear(nn.Module):
    def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout):
        super().__init__()
        self.linear = linear_layer

        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout)

        # Low-rank factors: A projects down to rank r, B projects back up.
        self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False)
        self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False)

        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))  # following microsoft/LoRA
        nn.init.zeros_(self.lora_B.weight)  # B starts at zero, so the wrapped layer is unchanged at init

        self.scaling = self.lora_alpha / self.lora_r
        # Freeze the original projection; only the low-rank factors are trained.
        self.linear.requires_grad_(False)

    def forward(self, x):
        # Frozen base projection plus the scaled low-rank update; standard LoRA
        # applies no non-linearity on the adapter branch (following microsoft/LoRA).
        out = self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
        return out
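Note (illustrative, not part of this diff): a quick sanity check of LoRALinear on a single layer. The wrapped projection is frozen, only the low-rank factors require gradients, and because lora_B starts at zero the module initially reproduces the base layer's output. The layer sizes are arbitrary.

import torch
import torch.nn as nn

base = nn.Linear(64, 32)
lora = LoRALinear(base, lora_r=8, lora_alpha=16.0, lora_dropout=0.0)

x = torch.randn(4, 64)
print(lora(x).shape)                     # torch.Size([4, 32]): same shape as the base layer
print(torch.allclose(lora(x), base(x)))  # True at init: lora_B is zero-initialised
print(base.weight.requires_grad)         # False: frozen by requires_grad_(False)
print(lora.lora_A.weight.requires_grad)  # True: only the low-rank factors train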

def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout):
    # Kept for reference only: superseded by replace_linear_with_lora below,
    # because setattr with the dotted names yielded by named_modules() does not
    # replace nested submodules (see the sketch after this function).
    for name, module in model.named_modules():
        if any(item in name for item in ['embed_prompts', 'lm_heads']):
            print('Ignored adding peft to ', name)
            continue

        if isinstance(module, nn.Linear):
            lora_linear = LoRALinear(module, lora_r, lora_alpha, lora_dropout)
            setattr(model, name, lora_linear)
    return model
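Note (illustrative, not part of this diff): why the `_old` variant is insufficient. `named_modules()` yields dotted names, and `setattr` with a dotted string only attaches a new attribute to the root module instead of replacing the nested child. A minimal demonstration on a toy module:

import torch.nn as nn

m = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))
dotted_name = "0.0"                       # as produced by named_modules() for the nested Linear
setattr(m, dotted_name, nn.Identity())

print(type(m[0][0]).__name__)             # Linear: the nested child was not replaced
print(type(getattr(m, "0.0")).__name__)   # Identity: a dangling attribute literally named "0.0"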

def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
    # Map every module to its dotted name so replacements can be matched and logged.
    full_name_dict = {module: name for name, module in model.named_modules()}
    linear_info = {}
    # Walk the module tree and record each nn.Linear together with its parent,
    # so the child can later be replaced in place via setattr on the parent.
    modules = [model]
    while len(modules) > 0:
        submodule = modules.pop()
        for name, raw_linear in submodule.named_children():
            if isinstance(raw_linear, torch.nn.Linear):
                full_name = full_name_dict[raw_linear]
                linear_info[raw_linear] = {
                    "father": submodule,
                    "name": name,
                    "full_name": full_name,
                }
            else:
                modules.append(raw_linear)

    total_len = sum(1 for _ in model.named_modules())

    for name, module in tqdm(model.named_modules(), total=total_len, desc='Replacing Linear with Low-Rank Layers', mininterval=5):
        if any(item in name for item in ['embed_prompts', 'lm_heads']):
            print('Ignored adding peft to ', name)

        elif module in linear_info:
            info = linear_info[module]
            new_module = LoRALinear(module, lora_r, lora_alpha, lora_dropout)
            setattr(info["father"], info["name"], new_module)

            del linear_info[module]
            torch.cuda.empty_cache()

    torch.cuda.empty_cache()
    print('Replaced linear layers with low-rank layers.')
    return model
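Note (illustrative, not part of this diff): running the helper on a small stand-in module to confirm that every nn.Linear in the tree, nested or not, ends up wrapped. The toy model below is purely for illustration.

import torch.nn as nn

toy = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Sequential(nn.Linear(32, 32), nn.GELU()),
    nn.Linear(32, 8),
)
toy = replace_linear_with_lora(toy, lora_r=4, lora_alpha=8.0, lora_dropout=0.05)

# Each of the three original nn.Linear layers is now a LoRALinear wrapper.
print(sum(isinstance(m, LoRALinear) for m in toy.modules()))  # 3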

def set_non_lora_gradients_to_false(model):
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False

        # Keep the prompt embeddings and the prediction heads fully trainable.
        if 'lm_heads' in name or 'embed_prompts' in name:
            param.requires_grad = True
            print("Using gradients for lm_heads or embed_prompts", name)
    return model
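Note (illustrative, not part of this diff): combining the two helpers on a toy module and checking which parameters remain trainable. Only the lora_A / lora_B factors survive the freeze (a toy module has no lm_heads or embed_prompts to re-enable).

import torch.nn as nn

tiny = replace_linear_with_lora(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4)),
                                lora_r=2, lora_alpha=4.0, lora_dropout=0.0)
tiny = set_non_lora_gradients_to_false(tiny)

print([n for n, p in tiny.named_parameters() if p.requires_grad])
# ['0.lora_A.weight', '0.lora_B.weight', '1.lora_A.weight', '1.lora_B.weight']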
15 changes: 15 additions & 0 deletions training/run_parler_tts_training.py
@@ -64,9 +64,11 @@
from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from training.eval import clap_similarity, wer, si_sdr
from training.peft_utils import replace_linear_with_lora, set_non_lora_gradients_to_false

logger = logging.getLogger(__name__)

os.environ["WANDB_MODE"] = "offline"
Author (review comment): to do: remove this line


def main():
    # See all possible arguments in src/transformers/training_args.py
@@ -333,6 +335,10 @@ def main():
        attn_implementation=model_args.attn_implementation,
    )

    if training_args.use_peft:
        logger.info('\n--Using PEFT, replacing layers--\n')
        replace_linear_with_lora(model.decoder, lora_r=training_args.lora_r, lora_alpha=training_args.lora_alpha, lora_dropout=training_args.lora_dropout)

    # enable gradient checkpointing if necessary
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
@@ -750,6 +756,8 @@ def compute_metrics(
    # Prepare everything with accelerate
    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

    print_trainable_params(model)  # print trainable params, for lora

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
    logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
@@ -1182,6 +1190,13 @@ def generate_step(batch, accelerator):

    accelerator.end_training()


def print_trainable_params(model):
    # Summarise how many parameters will actually be updated, e.g. to verify
    # that LoRA leaves only a small fraction of the model trainable.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    total_params = trainable + non_trainable
    trainable_percent = 100 * trainable / total_params
    print(f"Trainable: {trainable:,} | Non-trainable: {non_trainable:,} | Percent Trainable params: {trainable_percent:.2f}%")


if __name__ == "__main__":
    main()