Add Linear attention variation
This implements the linear attention variant as described mathematically in
https://arxiv.org/abs/2006.16236
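For reference, a sketch of the cited paper's causal linear attention, with feature map φ(x) = elu(x) + 1 (notation paraphrased from the paper):

\[
y_i = \frac{\phi(q_i)^\top \sum_{j=1}^{i} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{i} \phi(k_j)}
\]

Because the prefix sums over j can be updated incrementally, this replaces the O(T^2) softmax attention with a computation linear in sequence length.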
gkielian committed Feb 5, 2025
1 parent 9f2f909 commit 60408b2
Showing 2 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion train_args.py
@@ -275,7 +275,7 @@ def parse_args():
"--attention_variant",
type=str,
default="causal",
choices=["causal"],
choices=["causal", "linear"],
help="Which attention variant to use for the Transformer blocks."
)
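With this option in place, the variant can be selected at launch; a hypothetical invocation, assuming the usual train.py entry point consumes these arguments:

python train.py --attention_variant linear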

54 changes: 54 additions & 0 deletions variations/attention_variations.py
@@ -329,6 +329,60 @@ def forward(self, x, iter_num):

        return y

class LinearAttention(nn.Module):
    """ Implementation of Linear Attention.
    For a description of the algorithm, see:
    arXiv: https://arxiv.org/abs/2006.16236
    """
    def __init__(self, config, fire_pos_enc=None):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = config.n_embd // config.n_head

        # Combined linear layer for q, k, v
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        # Learnable scale, initialized to 1/sqrt(head_size)
        self.scale = torch.nn.Parameter(torch.tensor(1.0 / math.sqrt(self.head_size)))
    def forward(self, x, iter_num=None):
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        q = q.view(B, T, self.n_head, self.head_size)
        k = k.view(B, T, self.n_head, self.head_size)
        v = v.view(B, T, self.n_head, self.head_size)

        # Scale before the feature map
        q = q * self.scale
        k = k * self.scale

        # Feature map phi(x) = elu(x) + 1 keeps features strictly positive,
        # so the normalizer below cannot vanish
        q = F.elu(q) + 1
        k = F.elu(k) + 1

        # (B, T, H, D) -> (B, H, T, D)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Causal linear attention (Katharopoulos et al., 2020):
        #   y_t = phi(q_t)^T S_t / (phi(q_t)^T z_t),
        # with S_t = sum_{s<=t} phi(k_s) v_s^T and z_t = sum_{s<=t} phi(k_s).
        # Materializing the (B, H, T, D, D) prefix sums costs O(T * D^2) memory;
        # the paper's recurrent form avoids this at inference time.
        kv = torch.einsum("bhtd,bhte->bhtde", k, v)  # per-step outer products phi(k_s) v_s^T
        kv_cumsum = kv.cumsum(dim=2)                 # causal prefix sums S_t
        k_cumsum = k.cumsum(dim=2)                   # causal prefix sums z_t

        eps = 1e-3  # numerical floor for the normalizer
        y = (torch.einsum("bhtd,bhtde->bhte", q, kv_cumsum)
             / torch.einsum("bhtd,bhtd->bht", q, k_cumsum)[..., None].clamp(min=eps))

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)

        return y

attention_dictionary = {
    "causal": CausalSelfAttention,
    "linear": LinearAttention,
}
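A minimal smoke test of the lookup and forward pass (a sketch, relying on the file's existing imports; the SimpleNamespace config is a hypothetical stand-in for the project's real config object):

from types import SimpleNamespace
import torch

config = SimpleNamespace(n_embd=64, n_head=4, attention_variant="linear")
attn = attention_dictionary[config.attention_variant](config)  # -> LinearAttention

x = torch.randn(2, 16, config.n_embd)   # (batch, time, channels)
y = attn(x, iter_num=0)
print(y.shape)                          # torch.Size([2, 16, 64])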
