From 8560a65afff2552a50ace6ed6595fe8eb1593f89 Mon Sep 17 00:00:00 2001 From: Carolin Teuber <115626873+caroteu@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:25:56 +0100 Subject: [PATCH] Add LoRA scaling factor (#770) Add scaling parameter for the final outputs - Our preliminary experiments show that the scaling parameter is detrimental for values greater than 1 and less than 1. We set the parameter to 1. --------- Co-authored-by: Anwai Archit <52396323+anwai98@users.noreply.github.com> --- micro_sam/models/peft_sam.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index febbccf6..c1b5dcc6 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -27,11 +27,13 @@ def __init__(self, rank: int, block: nn.Module): super().__init__() self.qkv_proj = block.attn.qkv self.dim = self.qkv_proj.in_features + self.alpha = 1 # From our experiments, 'alpha' as 1 gives the best performance. + self.rank = rank - self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False) - self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False) - self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False) - self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False) + self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False) + self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False) + self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False) + self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False) self.reset_parameters() @@ -45,8 +47,8 @@ def reset_parameters(self): def forward(self, x): qkv = self.qkv_proj(x) # B, N, N, 3 * org_C - new_q = self.w_b_linear_q(self.w_a_linear_q(x)) - new_v = self.w_b_linear_v(self.w_a_linear_v(x)) + new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) + new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) qkv[:, :, :, :self.dim] += new_q qkv[:, :, :, -self.dim:] += new_v return qkv @@ -123,7 +125,7 @@ def 
allow_gradient_update_for_parameters( Args: prefix: Matches the part of parameter name in front. suffix: Matches the part of parameter name at the end. - infix: Matches parts of parameter name occuring in between. + infix: Matches parts of parameter name occuring in between. """ for k, v in self.block.named_parameters(): if prefix is not None and k.startswith(tuple(prefix)):