diff --git a/training/arguments.py b/training/arguments.py
index acc3f9b..5f1bdc0 100644
--- a/training/arguments.py
+++ b/training/arguments.py
@@ -360,3 +360,37 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
+    use_peft: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use parameter-efficient fine-tuning (LoRA) of the decoder transformer. Defaults to False (full fine-tuning)."
+        },
+    )
+
+    lora_r: int = field(
+        default=8,
+        metadata={
+            "help": (
+                "The rank of the low-rank adaptation matrices in LoRA."
+            )
+        },
+    )
+
+    lora_alpha: float = field(
+        default=16.0,
+        metadata={
+            "help": (
+                "The scaling factor for the LoRA updates. Controls the strength of the adaptation."
+            )
+        },
+    )
+
+    lora_dropout: float = field(
+        default=0.05,
+        metadata={
+            "help": (
+                "The dropout rate applied to the LoRA layers during training."
+            )
+        },
+    )
+
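For reference, the new flags surface through `HfArgumentParser`, the usual mechanism the training script relies on to build its argument dataclasses. The snippet below is a minimal sketch, assuming the project requirements (including `transformers`) are installed and it is run from the repository root so `training.arguments` is importable; the output directory value is a placeholder.

```python
# Sketch: parse the new LoRA flags into ParlerTTSTrainingArguments.
from transformers import HfArgumentParser

from training.arguments import ParlerTTSTrainingArguments

parser = HfArgumentParser(ParlerTTSTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--output_dir", "./tmp_lora_run",  # placeholder output directory
        "--use_peft",                      # enable LoRA on the decoder
        "--lora_r", "8",
        "--lora_alpha", "16.0",
        "--lora_dropout", "0.05",
    ]
)

assert training_args.use_peft and training_args.lora_r == 8
# Effective LoRA scaling applied to the low-rank update: alpha / r = 2.0
print(training_args.lora_alpha / training_args.lora_r)
```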
diff --git a/training/peft_utils.py b/training/peft_utils.py
new file mode 100644
index 0000000..bfe14e3
--- /dev/null
+++ b/training/peft_utils.py
@@ -0,0 +1,127 @@
+import torch.nn as nn
+import torch
+from tqdm import tqdm
+import math
+
+
+class LoRALinear(nn.Module):
+    def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout):
+        super().__init__()
+        self.linear = linear_layer
+
+        self.lora_r = lora_r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = nn.Dropout(p=lora_dropout)
+
+        self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False)
+        self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False)
+
+        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))  # following microsoft/LoRA
+        nn.init.zeros_(self.lora_B.weight)
+
+        self.scaling = self.lora_alpha / self.lora_r
+        self.linear.requires_grad_(False)  # freeze the original projection; only the adapters train
+
+    def forward(self, x):
+        # Frozen linear output plus the scaled low-rank update (no extra non-linearity,
+        # so the adapters can later be merged back into the base weights).
+        return self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+
+    def return_weights_without_lora(self):
+        """
+        After the adapters have been trained, returns the merged linear weights, i.e. the original
+        (frozen) weights combined with the trained LoRA update.
+        """
+        with torch.no_grad():
+            new_weight = self.linear.weight + (self.lora_B.weight @ self.lora_A.weight) * self.scaling
+            return new_weight
+
+
+def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
+    """
+    Replaces every nn.Linear layer in the model with a LoRALinear layer, in-place.
+    Layers whose name contains 'embed_prompts' or 'lm_heads' are left untouched.
+    Returns the model.
+    """
+    full_name_dict = {module: name for name, module in model.named_modules()}
+    linear_info = {}
+    modules = [model]
+    while len(modules) > 0:
+        submodule = modules.pop()
+        for name, raw_linear in submodule.named_children():
+            if isinstance(raw_linear, torch.nn.Linear):
+                full_name = full_name_dict[raw_linear]
+                linear_info[raw_linear] = {
+                    "father": submodule,
+                    "name": name,
+                    "full_name": full_name,
+                }
+            else:
+                modules.append(raw_linear)
+
+    total_len = sum(1 for _ in model.named_modules())
+
+    for name, module in tqdm(model.named_modules(), total=total_len, desc='Replacing Linear with Low-Rank Layers', mininterval=5):
+        if any(item in name for item in ['embed_prompts', 'lm_heads']):
+            print('Ignored adding peft to ', name)
+
+        elif module in linear_info:
+            info = linear_info[module]
+            new_module = LoRALinear(module, lora_r, lora_alpha, lora_dropout)
+            setattr(info["father"], info["name"], new_module)
+
+            del linear_info[module]
+            torch.cuda.empty_cache()
+
+    torch.cuda.empty_cache()
+    print('Replaced linear layers with low-rank layers.')
+    return model
+
+
+def replace_lora_with_linear(model):
+    """
+    Replaces every trained LoRALinear layer in the model with a plain nn.Linear layer whose
+    weights are the merged (original + LoRA) weights. This is done before saving, so the
+    saved checkpoint carries no LoRA adapters.
+    """
+    full_name_dict = {module: name for name, module in model.named_modules()}
+    linear_info = {}
+    modules = [model]
+    while len(modules) > 0:
+        submodule = modules.pop()
+        for name, raw_linear in submodule.named_children():
+            if isinstance(raw_linear, LoRALinear):
+                full_name = full_name_dict[raw_linear]
+                linear_info[raw_linear] = {
+                    "father": submodule,
+                    "name": name,
+                    "full_name": full_name,
+                }
+            else:
+                modules.append(raw_linear)
+
+    total_len = sum(1 for _ in model.named_modules())
+
+    for name, module in tqdm(model.named_modules(), total=total_len, desc='Removing LoRA layers', mininterval=5):
+        if module in linear_info:
+            info = linear_info[module]
+
+            # Merge the frozen weights with the trained adapters: W' = W + (B @ A) * scaling
+            new_linear_weight = module.return_weights_without_lora()  # shape: [out_features, in_features]
+
+            new_module = nn.Linear(
+                module.linear.in_features,
+                module.linear.out_features,
+                bias=module.linear.bias is not None,
+            )
+            new_module.weight.data = new_linear_weight
+            if module.linear.bias is not None:
+                new_module.bias.data = module.linear.bias.data.clone()
+
+            setattr(info["father"], info["name"], new_module)
+
+            del linear_info[module]
+            torch.cuda.empty_cache()
+
+    print('Replaced low-rank layers with linear layers.')
+    return model
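A quick way to sanity-check the merge logic above (a sketch, not part of the change): wrap a small stand-in model, give the `lora_B` matrices non-zero values as a stand-in for training, and confirm that `replace_lora_with_linear` produces plain `nn.Linear` layers that reproduce the adapted outputs, since the update folds back in as `W' = W + (B @ A) * scaling`. Assumes the snippet is run from the repository root so `training.peft_utils` is importable.

```python
import torch
import torch.nn as nn

from training.peft_utils import LoRALinear, replace_linear_with_lora, replace_lora_with_linear

torch.manual_seed(0)
# Tiny stand-in model; the real use case is the Parler-TTS decoder.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
replace_linear_with_lora(model, lora_r=4, lora_alpha=8.0, lora_dropout=0.0)

# Pretend training happened: give the B matrices non-zero values so the adapters matter.
with torch.no_grad():
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.lora_B.weight.normal_(std=0.02)

model.eval()
x = torch.randn(2, 16)
with torch.no_grad():
    out_with_adapters = model(x)
    merged = replace_lora_with_linear(model)   # merges adapters back into nn.Linear, in-place
    out_merged = merged(x)

print(torch.allclose(out_with_adapters, out_merged, atol=1e-5))  # expected: True
```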
diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index e4cd2da..503ba26 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -64,9 +64,11 @@
 from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
 from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
 from training.eval import clap_similarity, wer, si_sdr
+from training.peft_utils import replace_linear_with_lora, replace_lora_with_linear
 
 logger = logging.getLogger(__name__)
 
+os.environ["WANDB_MODE"] = "offline"
 
 def main():
     # See all possible arguments in src/transformers/training_args.py
@@ -333,6 +335,10 @@
         attn_implementation=model_args.attn_implementation,
     )
 
+    if training_args.use_peft:
+        logger.info('\n--Using PEFT: replacing decoder linear layers with LoRA--\n')
+        replace_linear_with_lora(model.decoder, lora_r=training_args.lora_r, lora_alpha=training_args.lora_alpha,
+                                 lora_dropout=training_args.lora_dropout)
+
     # enable gradient checkpointing if necessary
     if training_args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
@@ -750,6 +756,8 @@ def compute_metrics(
     # Prepare everything with accelerate
     model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
 
+    print_trainable_params(model)  # print trainable vs. frozen parameter counts, useful with LoRA
+
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
     logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
@@ -1182,6 +1190,13 @@ def generate_step(batch, accelerator):
 
     accelerator.end_training()
 
 
+def print_trainable_params(model):
+    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    total_params = trainable + non_trainable
+    trainable_percent = 100 * trainable / total_params
+    print(f"Trainable: {trainable:,} | Non-trainable: {non_trainable:,} | Percent trainable: {trainable_percent:.2f}%")
+
 
 if __name__ == "__main__":
     main()
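Finally, a rough illustration of what `print_trainable_params` should report once the decoder's linear layers are wrapped: only the `lora_A` / `lora_B` matrices keep `requires_grad=True`, so the trainable share drops sharply. This is a sketch using a small stand-in module with hypothetical sizes rather than the actual Parler-TTS decoder, again assuming it runs from the repository root.

```python
import torch.nn as nn

from training.peft_utils import replace_linear_with_lora

# Stand-in for a decoder: four 256x256 linear layers (hypothetical sizes).
decoder_stub = nn.Sequential(*[nn.Linear(256, 256) for _ in range(4)])
replace_linear_with_lora(decoder_stub, lora_r=8, lora_alpha=16.0, lora_dropout=0.05)

trainable = sum(p.numel() for p in decoder_stub.parameters() if p.requires_grad)
total = sum(p.numel() for p in decoder_stub.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
# Expected: 16,384 trainable (4 layers * 2 * 8 * 256) out of 279,552 total, i.e. ~5.9%.
```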