--- a/model.py
+++ b/model.py
@@ -7,6 +7,7 @@
import torch.nn as nn
from torch.nn import functional as F
+from torch.linalg import norm
from dataclasses import dataclass
@dataclass
@@ -128,7 +129,66 @@
        return logits, loss

+class PrincipalAttentionPursuit(nn.Module):
+    """
+    Implements the Principal Attention Pursuit (PAP) algorithm for RPC-Attention.
+    """
+
+    def __init__(self, n_iter, lambda_):
+        super().__init__()
+        self.n_iter = n_iter
+        self.lambda_ = lambda_
+
+    def shrinkage_operator(self, x, tau):
+        # Soft-thresholding: zero out entries with |x| <= tau, shrink the rest toward zero by tau
+        return torch.sign(x) * torch.relu(torch.abs(x) - tau)
+
+    def forward(self, q, k):
+        # q is accepted for interface symmetry but is not used here; the softmax with q
+        # is applied in the calling CausalSelfAttention, which receives (q, L) instead of (q, k)
+        B, T, C = k.size()
+
+        # Initialize the sparse component S and the dual variable Y
+        S = torch.zeros_like(k)
+        Y = torch.zeros_like(k)
+
+        # Step size mu, using the entry-wise l1 norm of k (computed on k's device)
+        mu = (B * T) / (4 * norm(k.flatten(), ord=1))
+
+        # Ensure lambda_ is on the same device as k
+        lambda_ = self.lambda_.to(k.device)
+
+        # Iterative refinement. L is initialized to k and, in this simplified
+        # implementation, is intentionally left unchanged while S and Y are updated.
+        L = k.clone()
+        for _ in range(self.n_iter):
+            S = self.shrinkage_operator(L - (mu**-1) * Y, lambda_ / mu)
+            Y = Y + mu * (k - L - S)
+
+        return L  # Return the refined key matrix L
+
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
+
+        # RPC-Attention parameters
+        self.use_rpc_attention = config.use_rpc_attention
+        if self.use_rpc_attention:
+            self.pap = PrincipalAttentionPursuit(n_iter=config.n_iter,
+                                                 lambda_=torch.tensor(config.lambda_))
@@ -153,14 +213,20 @@
        # q, k, v = (x.transpose(1, 2) for x in self.c_attn(x).split(self.n_embd, dim=2))
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        if self.use_rpc_attention:
+            # Apply PAP to refine the key matrix
+            k = self.pap(q, k)
+
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                                                                 dropout_p=self.dropout if self.training else 0,
+                                                                 is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -202,6 +268,7 @@
        super().__init__()
        self.block_size = config.block_size
        self.transformer = nn.ModuleDict(dict(
+            # wte and wpe are intentionally not created here; they are initialized in the GPT class
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
@@ -216,6 +283,28 @@
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+    def get_num_params(self, non_embedding=True):
+        """
+        Return the number of parameters in the model.
+        For non-embedding count (default), the position embeddings get subtracted.
+        The token embeddings would too, except due to the parameter sharing these
+        params are actually used as weights in the final layer, so we include them.
+        """
+        n_params = sum(p.numel() for p in self.parameters())
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -262,6 +351,8 @@
        self.use_flash = use_flash
        self.lambda_ = lambda_
        self.n_iter = n_iter
+        self.use_rpc_attention = use_rpc_attention
+
    def save(self, filepath):
        """
@@ -285,9 +376,9 @@
            block_size=self.block_size,
            n_layer=self.n_layer,
            n_head=self.n_head,
-            n_embd=self.n_embd,
            dropout=self.dropout,
            bias=self.bias,
+            n_embd=self.n_embd,
            use_flash=self.use_flash,
            lambda_=self.lambda_,
            n_iter=self.n_iter,
@@ -297,6 +388,10 @@
"""
# Override the default implementation to initialize the model using the config
config = self.get_config()
+
+ # Initialize token and position embeddings here
+ config.transformer.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
+ config.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
+
model = GPT(config)
return model
gkielian changed the title from "Test RPC attention" to "Add model variation for RPC attention" on Jan 10, 2025
Attempt to replicate the RPC style of attention; review the following paper:
https://arxiv.org/abs/2406.13762
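A rough smoke test along the following lines could exercise both the vanilla and RPC paths; this is only a sketch: SimpleNamespace stands in for the repo's actual config class, and the field names are taken from what the patched CausalSelfAttention reads, so they may need adjusting.

import torch
from types import SimpleNamespace
from model import CausalSelfAttention  # assumes the patched model.py is importable

# Minimal stand-in config carrying only the fields the patched attention reads
cfg = SimpleNamespace(
    n_embd=32, n_head=4, block_size=16,
    dropout=0.0, bias=False,
    use_rpc_attention=True, n_iter=5, lambda_=0.1,
)

attn = CausalSelfAttention(cfg)
x = torch.randn(2, 16, 32)  # (B, T, C)
y = attn(x)
print(y.shape)  # expected: torch.Size([2, 16, 32])

# Compare against the vanilla path with RPC disabled
cfg.use_rpc_attention = False
print(CausalSelfAttention(cfg)(x).shape)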