Add model variation for RPC attention #353

Open · gkielian opened this issue Jan 10, 2025 · 0 comments

gkielian (Collaborator) commented Jan 10, 2025

Attempt to replicate the RPC style of attention; review the following paper:

https://arxiv.org/abs/2406.13762
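
For context: Principal Component Pursuit (robust PCA) splits a matrix K into a low-rank part L and a sparse part S by solving

    minimize ||L||_* + lambda * ||S||_1    subject to    K = L + S,

with the usual defaults lambda = 1/sqrt(max(T, C)) and mu = T*C / (4 * ||K||_1) for a T x C matrix. As far as I can tell, the PAP loop in the diff below is meant to approximate this decomposition of the key matrix and then attend over the refined low-rank keys L; the exact update order should be checked against the paper.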

--- a/model.py
+++ b/model.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+from torch.linalg import norm
 from dataclasses import dataclass
 
 @dataclass
@@ -128,7 +129,66 @@
 
         return logits, loss
 
+class PrincipalAttentionPursuit(nn.Module):
+    """
+    Implements the Principal Attention Pursuit (PAP) algorithm for RPC-Attention.
+    """
+
+    def __init__(self, n_iter, lambda_):
+        super().__init__()
+        self.n_iter = n_iter
+        self.lambda_ = lambda_
+
+    def shrinkage_operator(self, x, tau):
+        return torch.sign(x) * torch.relu(torch.abs(x) - tau)
+
+    def forward(self, q, k):
+        # Initialize S and Y
+        B, T, C = k.size()
+        S = torch.zeros_like(k)
+        Y = torch.zeros_like(k)
+
+        # Calculate mu (entrywise L1 norm; torch.linalg.norm takes `ord`, not `p`)
+        mu = (B * T) / (4 * norm(k.flatten(), ord=1))
+
+        # Ensure mu and lambda_ are on the same device as k
+        mu = mu.to(k.device)
+        lambda_ = self.lambda_.to(k.device)
+
+        # Iterative refinement. This follows the standard Principal Component
+        # Pursuit ALM updates (singular value thresholding for the low-rank part L,
+        # soft thresholding for the sparse part S); the paper's PAP variant should
+        # be checked against this ordering.
+        L = k.clone()  # Initialize L to k
+        for _ in range(self.n_iter):
+            # Low-rank update: singular value thresholding of k - S + mu^-1 * Y
+            U, sigma, Vh = torch.linalg.svd(k - S + (1.0 / mu) * Y, full_matrices=False)
+            sigma = torch.relu(sigma - 1.0 / mu)
+            L = U @ torch.diag_embed(sigma) @ Vh
+            # Sparse update: soft thresholding of the residual
+            S = self.shrinkage_operator(k - L + (1.0 / mu) * Y, lambda_ / mu)
+            # Dual update
+            Y = Y + mu * (k - L - S)
+
+        # The calling CausalSelfAttention applies softmax as usual, but on (q, L)
+        # instead of (q, k), i.e. the refined keys replace the raw keys.
+        return L
+
 class CausalSelfAttention(nn.Module):
 
     def __init__(self, config):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
         # output projection
         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
 
+        # RPC-Attention parameters (appended to the existing __init__ rather than
+        # redefining the whole method)
+        self.use_rpc_attention = config.use_rpc_attention
+        if self.use_rpc_attention:
+            self.pap = PrincipalAttentionPursuit(n_iter=config.n_iter, lambda_=torch.tensor(config.lambda_))
@@ -153,14 +213,20 @@
         # q, k, v = (x.transpose(1, 2) for x in self.c_attn(x).split(self.n_embd, dim=2))
         q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
         
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        if self.use_rpc_attention:
+            # Apply PAP to refine the key matrix
+            k = self.pap(q, k)
+            
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
 
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                                                                 dropout_p=self.dropout if self.training else 0,
+                                                                 is_causal=True)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -202,6 +268,7 @@
         super().__init__()
         self.block_size = config.block_size
         self.transformer = nn.ModuleDict(dict(
+            # wte and wpe are no longer created here; they are added when the model is built (see the last hunk below)
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
             ln_f = LayerNorm(config.n_embd, bias=config.bias),
@@ -216,6 +283,28 @@
         self.apply(self._init_weights)
         # apply special scaled init to the residual projections, per GPT-2 paper
         for pn, p in self.named_parameters():
             if pn.endswith('c_proj.weight'):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+    def get_num_params(self, non_embedding=True):
+        """
+        Return the number of parameters in the model.
+        For non-embedding count (default), the position embeddings get subtracted.
+        The token embeddings would too, except due to the parameter sharing these
+        params are actually used as weights in the final layer, so we include them.
+        """
+        n_params = sum(p.numel() for p in self.parameters())
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
@@ -262,6 +351,8 @@
         self.use_flash = use_flash
         self.lambda_ = lambda_
         self.n_iter = n_iter
+        self.use_rpc_attention = use_rpc_attention
+        
 
     def save(self, filepath):
         """
@@ -285,9 +376,9 @@
             block_size=self.block_size,
             n_layer=self.n_layer,
             n_head=self.n_head,
-            n_embd=self.n_embd,
             dropout=self.dropout,
             bias=self.bias,
+            n_embd=self.n_embd,
             use_flash=self.use_flash,
             lambda_ = self.lambda_,
             n_iter = self.n_iter,
@@ -297,6 +388,10 @@
         """
         # Override the default implementation to initialize the model using the config
         config = self.get_config()
         model = GPT(config)
+
+        # Initialize token and position embeddings on the constructed model
+        # (the config object has no transformer module to attach them to)
+        model.transformer.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
+        model.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)
+
         return model
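
Standalone smoke test for the PAP module (sketch only; assumes the diff above has been applied so that PrincipalAttentionPursuit is importable from model.py — everything else here is just for illustration):

# Quick check that the PAP decomposition runs and preserves the key shape
import torch
from model import PrincipalAttentionPursuit

B, T, C = 2, 16, 64
q = torch.randn(B, T, C)
k = torch.randn(B, T, C)

# lambda_ = 1/sqrt(max(T, C)) is the usual Principal Component Pursuit default
pap = PrincipalAttentionPursuit(n_iter=5, lambda_=torch.tensor(1.0 / max(T, C) ** 0.5))
L = pap(q, k)

assert L.shape == k.shape
# k should decompose into the refined (low-rank) keys L plus a sparse residual
print("mean |k - L|:", (k - L).abs().mean().item())

On the training side, enabling the variation should then just be a matter of setting use_rpc_attention=True (along with n_iter and lambda_) in the model config.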

gkielian changed the title from "Test RPC attention" to "Add model variation for RPC attention" on Jan 10, 2025