
Commit

Also add LoRA in SelfAttention (for the value projection)
marcoyang1998 committed Mar 8, 2024
1 parent 9bc1ad8 commit 5272a71
Showing 1 changed file with 27 additions and 3 deletions.
egs/librispeech/ASR/zipformer_lora/zipformer.py (27 additions, 3 deletions)
@@ -634,9 +634,23 @@ def __init__(
             lora_dropout=lora_dropout,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+        self.self_attn1 = SelfAttention(
+            embed_dim,
+            num_heads,
+            value_head_dim,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
-        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+        self.self_attn2 = SelfAttention(
+            embed_dim,
+            num_heads,
+            value_head_dim,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
         self.feed_forward1 = FeedforwardModule(
             embed_dim, (feedforward_dim * 3) // 4, dropout
@@ -1901,9 +1915,19 @@ def __init__(
         embed_dim: int,
         num_heads: int,
         value_head_dim: int,
+        lora_r: int = 0,
+        lora_alpha: int = 4,
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+        self.in_proj = ScaledLinear_lora(
+            in_features=embed_dim,
+            out_features=num_heads * value_head_dim,
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            bias=True,
+        )
 
         self.out_proj = ScaledLinear(
             num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
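The new constructor arguments mirror the usual LoRA hyperparameters: r is the rank of the low-rank update, lora_alpha the scaling numerator, and lora_dropout a dropout applied on the LoRA path. The actual ScaledLinear_lora is defined elsewhere in the zipformer_lora recipe, so the sketch below is only an assumption about what its arguments suggest: a frozen base projection plus a trainable update W + (alpha / r) * B A. The class name LoRALinear is hypothetical.

# Minimal sketch of a LoRA-augmented linear layer (assumed formulation, not the
# recipe's ScaledLinear_lora implementation).
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight and a trainable low-rank update."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 4,
        lora_dropout: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=bias)
        # Only the LoRA factors are trained; the base projection stays frozen.
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.r = r
        if r > 0:
            # Effective weight: W + (lora_alpha / r) * B @ A,
            # with A of shape (r, in_features) and B of shape (out_features, r).
            self.lora_A = nn.Parameter(torch.empty(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            self.scaling = lora_alpha / r
            self.dropout = nn.Dropout(lora_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if self.r > 0:
            # Low-rank path: dropout -> A -> B, scaled by lora_alpha / r.
            lora_out = self.dropout(x) @ self.lora_A.t() @ self.lora_B.t()
            out = out + lora_out * self.scaling
        return out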

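As a usage note under the same assumptions, freezing the base weight means that after adding LoRA to the value projection only the low-rank factors show up as trainable parameters. The dimensions below are illustrative values, not taken from the recipe's configuration.

# Usage sketch with the hypothetical LoRALinear above: a value projection like
# the one in SelfAttention, with rank-4 LoRA.
embed_dim, num_heads, value_head_dim = 384, 8, 12
value_proj = LoRALinear(
    in_features=embed_dim,
    out_features=num_heads * value_head_dim,
    r=4,
    lora_alpha=4,
    lora_dropout=0.1,
)
trainable = [name for name, p in value_proj.named_parameters() if p.requires_grad]
print(trainable)  # ['lora_A', 'lora_B']
x = torch.randn(10, 2, embed_dim)  # (seq_len, batch, embed_dim)
print(value_proj(x).shape)  # torch.Size([10, 2, 96])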