diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index febbccf6..c1b5dcc6 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -27,11 +27,13 @@ def __init__(self, rank: int, block: nn.Module):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
+        self.alpha = 1  # From our experiments, setting 'alpha' to 1 gives the best performance.
+        self.rank = rank
 
-        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
-        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
-        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
+        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
+        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
+        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
 
         self.reset_parameters()
 
@@ -45,8 +47,8 @@ def reset_parameters(self):
 
     def forward(self, x):
         qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
-        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
-        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
+        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
+        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
         qkv[:, :, :, :self.dim] += new_q
         qkv[:, :, :, -self.dim:] += new_v
         return qkv
@@ -123,7 +125,7 @@ def allow_gradient_update_for_parameters(
         Args:
            prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
-           infix: Matches parts of parameter name occuring in between.
+           infix: Matches parts of parameter name occurring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
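
The net effect of the first two hunks is that the low-rank update is scaled by a constant: q becomes q + alpha * W_b_q(W_a_q(x)) and v becomes v + alpha * W_b_v(W_a_v(x)). Below is a minimal standalone sketch of that adapted forward pass, outside the diff; the tensor sizes, the shorthand adapter names, and the zero-initialization of the B matrices (reset_parameters is not shown in full above) are illustrative assumptions rather than micro_sam's exact code:

import torch
import torch.nn as nn

dim, rank, alpha = 64, 4, 1                  # illustrative sizes; alpha fixed to 1 as in the diff

qkv_proj = nn.Linear(dim, 3 * dim)           # stands in for the frozen block.attn.qkv
w_a_q, w_b_q = nn.Linear(dim, rank, bias=False), nn.Linear(rank, dim, bias=False)
w_a_v, w_b_v = nn.Linear(dim, rank, bias=False), nn.Linear(rank, dim, bias=False)
nn.init.zeros_(w_b_q.weight)                 # assumed zero-init of the B matrices, the common LoRA choice
nn.init.zeros_(w_b_v.weight)

x = torch.randn(2, 8, 8, dim)                # B, N, N, C, matching the comment in forward()
qkv = qkv_proj(x)                            # B, N, N, 3 * C
qkv[..., :dim] += alpha * w_b_q(w_a_q(x))    # scaled low-rank update on the q slice
qkv[..., -dim:] += alpha * w_b_v(w_a_v(x))   # scaled low-rank update on the v slice

With B zero-initialized, the adapted block initially reproduces the frozen model exactly, and alpha only rescales the learned correction; the diff pins it to 1 based on the authors' experiments.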