Skip to content

Commit

Permalink
Add LoRA scaling factor (#770)
Browse files Browse the repository at this point in the history
Add scaling parameter for the final outputs
- Our preliminary experiments show that the scaling parameter is detrimental for >1 and <1. We set the parameter to 1.

---------
Co-authored-by: Anwai Archit <[email protected]>
  • Loading branch information
caroteu authored Nov 13, 2024
1 parent 8801da3 commit 8560a65
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions micro_sam/models/peft_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ def __init__(self, rank: int, block: nn.Module):
super().__init__()
self.qkv_proj = block.attn.qkv
self.dim = self.qkv_proj.in_features
self.alpha = 1 # From our experiments, 'alpha' as 1 gives the best performance.
self.rank = rank

self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

self.reset_parameters()

Expand All @@ -45,8 +47,8 @@ def reset_parameters(self):

def forward(self, x):
qkv = self.qkv_proj(x) # B, N, N, 3 * org_C
new_q = self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.w_b_linear_v(self.w_a_linear_v(x))
new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
qkv[:, :, :, -self.dim:] += new_v
return qkv
Expand Down Expand Up @@ -123,7 +125,7 @@ def allow_gradient_update_for_parameters(
Args:
prefix: Matches the part of parameter name in front.
suffix: Matches the part of parameter name at the end.
infix: Matches parts of parameter name occuring in between.
infix: Matches parts of parameter name occuring in between.
"""
for k, v in self.block.named_parameters():
if prefix is not None and k.startswith(tuple(prefix)):
Expand Down

0 comments on commit 8560a65

Please sign in to comment.