From 2b74d397380a62e43731cb7076b01b30b8f00d10 Mon Sep 17 00:00:00 2001 From: Sidhant Date: Fri, 25 Oct 2024 18:51:48 +0530 Subject: [PATCH 01/16] utils for peft --- training/peft_utils.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 training/peft_utils.py diff --git a/training/peft_utils.py b/training/peft_utils.py new file mode 100644 index 0000000..22bf6d1 --- /dev/null +++ b/training/peft_utils.py @@ -0,0 +1,39 @@ +import torch.nn as nn + +class LoRALinear(nn.Module): + def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): + super().__init__() + self.linear = linear_layer + self.lora_r = lora_r + self.lora_alpha = lora_alpha + self.lora_dropout = nn.Dropout(p=lora_dropout) + + self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False) + self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False) + self.scaling = self.lora_alpha / self.lora_r + + def forward(self, x): + x = self.lora_dropout(x) + x = self.linear(x) + self.lora_B(self.lora_A(x)) * self.scaling + return x + +def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout): + for name, module in model.named_modules(): + if any(item in name for item in ['embed_prompts', 'lm_heads']): + print('Ignored adding peft to ', name) + continue + + if isinstance(module, nn.Linear): + lora_linear = LoRALinear(module, lora_r, lora_alpha, lora_dropout) + setattr(model, name, lora_linear) + return model + +def set_non_lora_gradients_to_false(model): + for name, param in model.named_parameters(): + if "lora_" not in name: + param.requires_grad = False + + if 'lm_heads' in name or 'embed_prompts' in name: + param.requires_grad = True + print("Using gradients for lm_heads or embed_prompts", name) + return model \ No newline at end of file From 4206a9991af01a736d414b79476fd254a7315bd0 Mon Sep 17 00:00:00 2001 From: Sidhant Date: Fri, 25 Oct 2024 18:52:34 +0530 Subject: [PATCH 02/16] add basic peft --- training/run_parler_tts_training.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index e4cd2da..2d2bcc0 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -64,6 +64,7 @@ from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding from training.eval import clap_similarity, wer, si_sdr +from training.peft_utils import replace_linear_with_lora, set_non_lora_gradients_to_false logger = logging.getLogger(__name__) @@ -333,6 +334,11 @@ def main(): attn_implementation=model_args.attn_implementation, ) + do_peft = True + if do_peft: + replace_linear_with_lora(model.decoder, lora_r=16, lora_alpha=32, lora_dropout=0.05) + set_non_lora_gradients_to_false(model.decoder) + # enable gradient checkpointing if necessary if training_args.gradient_checkpointing: model.gradient_checkpointing_enable() From 7f1a7627fae56387b5e118487922463712562759 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Sun, 27 Oct 2024 15:25:47 +0000 Subject: [PATCH 03/16] fix peft by init 0 --- training/peft_utils.py | 47 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index 22bf6d1..3f48715 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -1,4 +1,6 @@ import torch.nn as nn +import torch +from tqdm 
import tqdm class LoRALinear(nn.Module): def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): @@ -10,6 +12,10 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False) self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False) + + init.zeros_(self.lora_A.weight) + init.zeros_(self.lora_B.weight) + self.scaling = self.lora_alpha / self.lora_r def forward(self, x): @@ -17,7 +23,7 @@ def forward(self, x): x = self.linear(x) + self.lora_B(self.lora_A(x)) * self.scaling return x -def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout): +def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout): for name, module in model.named_modules(): if any(item in name for item in ['embed_prompts', 'lm_heads']): print('Ignored adding peft to ', name) @@ -28,6 +34,43 @@ def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout): setattr(model, name, lora_linear) return model +def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout): + full_name_dict = {module: name for name, module in model.named_modules()} + linear_info = {} + modules = [model] + while len(modules) > 0: + submodule = modules.pop() + for name, raw_linear in submodule.named_children(): + if isinstance(raw_linear, torch.nn.Linear): + full_name = full_name_dict[raw_linear] + linear_info[raw_linear] = { + "father": submodule, + "name": name, + "full_name": full_name, + } + else: + modules.append(raw_linear) + + for total_len, _ in enumerate(model.named_modules()): + pass + + i = 0 + for name, module in tqdm(model.named_modules(), total=total_len, desc='Replacing Linear with Low-Rank Layers', mininterval=5): + if any(item in name for item in ['embed_prompts', 'lm_heads']): + print('Ignored adding peft to ', name) + + elif module in linear_info: + info = linear_info[module] + new_module = LoRALinear(module, lora_r, lora_alpha, lora_dropout) + setattr(info["father"], info["name"], new_module) + + del linear_info[module] + torch.cuda.empty_cache() + + torch.cuda.empty_cache() + print('Replaced linear layers with low-rank layers.') + return model + def set_non_lora_gradients_to_false(model): for name, param in model.named_parameters(): if "lora_" not in name: @@ -36,4 +79,4 @@ def set_non_lora_gradients_to_false(model): if 'lm_heads' in name or 'embed_prompts' in name: param.requires_grad = True print("Using gradients for lm_heads or embed_prompts", name) - return model \ No newline at end of file + return model From c85c0da1e52e07756a77752a43c3613dd583be16 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Sun, 27 Oct 2024 19:03:22 +0000 Subject: [PATCH 04/16] fix peft, init adapters zero --- training/peft_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index 3f48715..d792bc7 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -13,8 +13,8 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False) self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False) - init.zeros_(self.lora_A.weight) - init.zeros_(self.lora_B.weight) + nn.init.zeros_(self.lora_A.weight) + nn.init.zeros_(self.lora_B.weight) self.scaling = self.lora_alpha / self.lora_r From 6d381d4e23c779d3e25375b00130f4ffe6d72b43 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Sun, 27 Oct 2024 19:03:38 +0000 Subject: [PATCH 05/16] add peft training arg 
--- training/arguments.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/training/arguments.py b/training/arguments.py index acc3f9b..847d214 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -360,3 +360,9 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments): ) }, ) + use_peft: bool = field( + default=False, + metadata={ + "help": "Flag to use parameter efficient fine-tuning, with LORA, for the decoder transformer. Default is without LORA" + }, + ) From e6fc753583535f7840ae8e3553f7b0e09e58089f Mon Sep 17 00:00:00 2001 From: sidhantls Date: Sun, 27 Oct 2024 19:04:02 +0000 Subject: [PATCH 06/16] integrate use_peft commandline arg --- training/run_parler_tts_training.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 2d2bcc0..f12fad5 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -68,6 +68,7 @@ logger = logging.getLogger(__name__) +os.environ["WANDB_MODE"] = "offline" def main(): # See all possible arguments in src/transformers/training_args.py @@ -334,10 +335,11 @@ def main(): attn_implementation=model_args.attn_implementation, ) - do_peft = True - if do_peft: - replace_linear_with_lora(model.decoder, lora_r=16, lora_alpha=32, lora_dropout=0.05) + if training_args.use_peft == True: + replace_linear_with_lora(model.decoder, lora_r=8, lora_alpha=16, lora_dropout=0.05) set_non_lora_gradients_to_false(model.decoder) + + print_trainable_params(model) # enable gradient checkpointing if necessary if training_args.gradient_checkpointing: @@ -1188,6 +1190,13 @@ def generate_step(batch, accelerator): accelerator.end_training() +def print_trainable_params(model): + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad) + total_params = trainable + non_trainable + trainable_percent = 100 * trainable / total_params + print(f"Trainable: {trainable:,} | Non-trainable: {non_trainable:,} | Percent Trainable params: {trainable_percent:.2f}%") + if __name__ == "__main__": main() From 3b4fc620563d14c7fe1be21c2f57ca3c49bc7ad3 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Sun, 27 Oct 2024 19:24:21 +0000 Subject: [PATCH 07/16] improve lora initialization --- training/peft_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index d792bc7..c5f2409 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -1,6 +1,7 @@ import torch.nn as nn import torch from tqdm import tqdm +import math class LoRALinear(nn.Module): def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): @@ -13,15 +14,15 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False) self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False) - nn.init.zeros_(self.lora_A.weight) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) # following microsoft/LoRA nn.init.zeros_(self.lora_B.weight) self.scaling = self.lora_alpha / self.lora_r def forward(self, x): - x = self.lora_dropout(x) - x = self.linear(x) + self.lora_B(self.lora_A(x)) * self.scaling - return x + out = self.linear(x) + out = out + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + return out def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout): for name, module in 
model.named_modules(): From 11ea91766b5b6fc600159d47dd85984cb3f003b6 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Mon, 28 Oct 2024 19:56:01 +0000 Subject: [PATCH 08/16] add lora args --- training/arguments.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/training/arguments.py b/training/arguments.py index 847d214..5f1bdc0 100644 --- a/training/arguments.py +++ b/training/arguments.py @@ -366,3 +366,31 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments): "help": "Flag to use parameter efficient fine-tuning, with LORA, for the decoder transformer. Default is without LORA" }, ) + + lora_r: int = field( + default=8, + metadata={ + "help": ( + "The rank of the low-rank adaptation matrices in LORA. " + ) + }, + ) + + lora_alpha: float = field( + default=16.0, + metadata={ + "help": ( + "The scaling factor for the LORA updates. Controls strength of the adaptation " + ) + }, + ) + + lora_dropout: float = field( + default=0.05, + metadata={ + "help": ( + "The dropout rate applied to the LORA layers during training. " + ) + }, + ) + From 5b6ec158c9daaa9caf290828cedae9f8c30d0254 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Mon, 28 Oct 2024 19:56:25 +0000 Subject: [PATCH 09/16] fix use nn.init.kaiming_uniform_ --- training/peft_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index c5f2409..8435fb0 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -14,8 +14,9 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): self.lora_A = nn.Linear(linear_layer.in_features, lora_r, bias=False) self.lora_B = nn.Linear(lora_r, linear_layer.out_features, bias=False) - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) # following microsoft/LoRA + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) # following microsoft/LoRA nn.init.zeros_(self.lora_B.weight) + #nn.init.zeros_(self.lora_A.weight) self.scaling = self.lora_alpha / self.lora_r From f9dd3d642f0b4235162bc7666b91f1c3105f21f9 Mon Sep 17 00:00:00 2001 From: sidhantls Date: Mon, 28 Oct 2024 19:56:44 +0000 Subject: [PATCH 10/16] pass lora params from training args --- training/run_parler_tts_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index f12fad5..7d3f7eb 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -336,7 +336,7 @@ def main(): ) if training_args.use_peft == True: - replace_linear_with_lora(model.decoder, lora_r=8, lora_alpha=16, lora_dropout=0.05) + replace_linear_with_lora(model.decoder, lora_r=training_args.lora_r, lora_alpha=training_args.lora_alpha, lora_dropout=training_args.lora_dropout) set_non_lora_gradients_to_false(model.decoder) print_trainable_params(model) From 7b21c2d65848f91f6b536eef3e3cc475646926d4 Mon Sep 17 00:00:00 2001 From: sidhant ls Date: Tue, 29 Oct 2024 14:10:02 +0530 Subject: [PATCH 11/16] set gradient during lora replacment, do lora in one line --- training/peft_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index 8435fb0..82a765f 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -7,6 +7,7 @@ class LoRALinear(nn.Module): def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): super().__init__() self.linear = linear_layer + self.lora_r = lora_r self.lora_alpha = lora_alpha self.lora_dropout = 
nn.Dropout(p=lora_dropout) @@ -19,10 +20,10 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): #nn.init.zeros_(self.lora_A.weight) self.scaling = self.lora_alpha / self.lora_r - + self.linear.requires_grad_(False) + def forward(self, x): - out = self.linear(x) - out = out + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + out = self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling return out def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout): From 3cf84eaf57b503c38527fadf257929b6814c07d7 Mon Sep 17 00:00:00 2001 From: sidhant ls Date: Tue, 29 Oct 2024 14:12:42 +0530 Subject: [PATCH 12/16] dont set grads with separate fn rather use the grads set during lora replacement --- training/run_parler_tts_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py index 7d3f7eb..d1a25c1 100644 --- a/training/run_parler_tts_training.py +++ b/training/run_parler_tts_training.py @@ -336,11 +336,9 @@ def main(): ) if training_args.use_peft == True: + logger.info('\n--Using PEFT, replacing layers--\n') replace_linear_with_lora(model.decoder, lora_r=training_args.lora_r, lora_alpha=training_args.lora_alpha, lora_dropout=training_args.lora_dropout) - set_non_lora_gradients_to_false(model.decoder) - print_trainable_params(model) - # enable gradient checkpointing if necessary if training_args.gradient_checkpointing: model.gradient_checkpointing_enable() @@ -758,6 +756,8 @@ def compute_metrics( # Prepare everything with accelerate model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + print_trainable_params(model) # print trainable params, for lora + logger.info("***** Running training *****") logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}") logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}") From 714c23ff0900e4e2a0f8e75dc0345442401da499 Mon Sep 17 00:00:00 2001 From: sidhant ls Date: Wed, 30 Oct 2024 20:54:54 +0530 Subject: [PATCH 13/16] remove unused functions --- training/peft_utils.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/training/peft_utils.py b/training/peft_utils.py index 82a765f..5493fb2 100644 --- a/training/peft_utils.py +++ b/training/peft_utils.py @@ -17,7 +17,6 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout): nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) # following microsoft/LoRA nn.init.zeros_(self.lora_B.weight) - #nn.init.zeros_(self.lora_A.weight) self.scaling = self.lora_alpha / self.lora_r self.linear.requires_grad_(False) @@ -26,18 +25,10 @@ def forward(self, x): out = self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling return out -def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout): - for name, module in model.named_modules(): - if any(item in name for item in ['embed_prompts', 'lm_heads']): - print('Ignored adding peft to ', name) - continue - - if isinstance(module, nn.Linear): - lora_linear = LoRALinear(module, lora_r, lora_alpha, lora_dropout) - setattr(model, name, lora_linear) - return model - def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout): + """ + Given a model, replaces all linear layers with a Linear LORA layer in-place. 
Returns model
+    """
     full_name_dict = {module: name for name, module in model.named_modules()}
     linear_info = {}
     modules = [model]
@@ -74,12 +65,3 @@ def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
     print('Replaced linear layers with low-rank layers.')
     return model
 
-def set_non_lora_gradients_to_false(model):
-    for name, param in model.named_parameters():
-        if "lora_" not in name:
-            param.requires_grad = False
-
-        if 'lm_heads' in name or 'embed_prompts' in name:
-            param.requires_grad = True
-            print("Using gradients for lm_heads or embed_prompts", name)
-    return model

From 2dc4875697f4fcc87c1093e3b11e4d6b90739ab1 Mon Sep 17 00:00:00 2001
From: sidhant ls
Date: Thu, 31 Oct 2024 00:54:22 +0530
Subject: [PATCH 14/16] add fn to convert lora to linear

---
 training/peft_utils.py | 52 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/training/peft_utils.py b/training/peft_utils.py
index 5493fb2..88f871c 100644
--- a/training/peft_utils.py
+++ b/training/peft_utils.py
@@ -25,6 +25,15 @@ def forward(self, x):
         out = self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
         return out
 
+    def return_weights_without_lora(self):
+        """
+        After the adapters have been trained, returns the merged linear weights: the original linear layer's
+        weights combined with the trained PEFT adapters.
+        """
+        with torch.no_grad():
+            new_weight = self.linear.weight + (self.lora_B.weight @ self.lora_A.weight) * self.scaling
+        return new_weight
+
 def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
     """
     Given a model, replaces all linear layers with a Linear LORA layer in-place. Returns model
     """
@@ -65,3 +74,46 @@ def replace_linear_with_lora(model, lora_r, lora_alpha, lora_dropout):
     print('Replaced linear layers with low-rank layers.')
     return model
+
+def replace_lora_with_linear(model):
+    """
+    Given a model with LoRA adapters, replaces each trained LoRALinear layer with a regular nn.Linear
+    layer. This is done before saving, so the exported model contains no LoRA adapters.
+    """
+    full_name_dict = {module: name for name, module in model.named_modules()}
+    linear_info = {}
+    modules = [model]
+    while len(modules) > 0:
+        submodule = modules.pop()
+        for name, raw_linear in submodule.named_children():
+            if isinstance(raw_linear, LoRALinear):
+                full_name = full_name_dict[raw_linear]
+                linear_info[raw_linear] = {
+                    "father": submodule,
+                    "name": name,
+                    "full_name": full_name,
+                }
+            else:
+                modules.append(raw_linear)
+
+    for total_len, _ in enumerate(model.named_modules()):
+        pass
+
+    i = 0
+    for name, module in tqdm(model.named_modules(), total=total_len, desc='Removing LoRA layers', mininterval=5):
+        if module in linear_info:
+            info = linear_info[module]
+
+            weight = module.linear.weight.data.clone()  # Shape: [out_features, in_features]
+            new_linear_weight = module.return_weights_without_lora()
+
+            new_module = nn.Linear(module.linear.in_features, module.linear.out_features)
+            new_module.weight.data = new_linear_weight
+
+            setattr(info["father"], info["name"], new_module)
+
+            del linear_info[module]
+            torch.cuda.empty_cache()
+
+    print('Replaced LoRA layers with plain linear layers.')
+    return model

From 81b27192019dfe374fcc05082b54a4f12d152062 Mon Sep 17 00:00:00 2001
From: sidhant ls
Date: Thu, 31 Oct 2024 00:55:43 +0530
Subject: [PATCH 15/16] remove unused import

---
 training/run_parler_tts_training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index d1a25c1..503ba26 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -64,7 +64,7 @@
 from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
 from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
 from training.eval import clap_similarity, wer, si_sdr
-from training.peft_utils import replace_linear_with_lora, set_non_lora_gradients_to_false
+from training.peft_utils import replace_linear_with_lora, replace_lora_with_linear
 
 logger = logging.getLogger(__name__)

From 74f0c7e467c3642a482350a71e18d1fdbd5b6a00 Mon Sep 17 00:00:00 2001
From: sidhantls
Date: Thu, 31 Oct 2024 17:34:57 +0000
Subject: [PATCH 16/16] add relu for training stability

---
 training/peft_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/training/peft_utils.py b/training/peft_utils.py
index 82a765f..e9ff3f2 100644
--- a/training/peft_utils.py
+++ b/training/peft_utils.py
@@ -17,13 +17,12 @@ def __init__(self, linear_layer, lora_r, lora_alpha, lora_dropout):
         nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) # following microsoft/LoRA
         nn.init.zeros_(self.lora_B.weight)
-        #nn.init.zeros_(self.lora_A.weight)
 
         self.scaling = self.lora_alpha / self.lora_r
         self.linear.requires_grad_(False)
 
     def forward(self, x):
-        out = self.linear(x) + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+        out = self.linear(x) + torch.relu(self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling)
         return out
 
 def replace_linear_with_lora_old(model, lora_r, lora_alpha, lora_dropout):
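
Usage note (illustrative; not part of the patch series): the sketch below shows how the utilities introduced above are intended to compose: wrap the decoder's linear layers in LoRA adapters before training, then merge the adapters back into plain linear layers before saving. Function names, defaults, and the module path follow the patches; the toy decoder, the elided training loop, and the CLI flag names (--use_peft, --lora_r, --lora_alpha, --lora_dropout, assuming the usual HfArgumentParser field-to-flag mapping) are placeholders and assumptions, not verified commands.

    import torch.nn as nn

    from training.peft_utils import replace_linear_with_lora, replace_lora_with_linear

    # Stand-in for model.decoder; any module tree containing nn.Linear layers works the same way.
    decoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

    # Wrap every nn.Linear in a LoRALinear (modules whose names contain 'embed_prompts' or 'lm_heads' are skipped).
    replace_linear_with_lora(decoder, lora_r=8, lora_alpha=16.0, lora_dropout=0.05)

    # Only the adapter weights require gradients; the wrapped base weights were frozen in LoRALinear.__init__.
    print([name for name, p in decoder.named_parameters() if p.requires_grad])
    # e.g. ['0.lora_A.weight', '0.lora_B.weight', '2.lora_A.weight', '2.lora_B.weight']

    # ... run the usual training loop here ...

    # Merge weight + (lora_B @ lora_A) * scaling back into plain nn.Linear layers before saving.
    replace_lora_with_linear(decoder)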