diff --git a/index.html b/index.html index b45fe14..28026a9 100644 --- a/index.html +++ b/index.html @@ -106,6 +106,17 @@

Flash Attention

Read more ⟶ +
+

Sliding Window Attention

+ +
+ + Altering the tokens to which a token in the input sequence attends. + +
+ Read more ⟶ +
+

Sparse Attention

diff --git a/index.xml b/index.xml index 265030e..c25a1bc 100644 --- a/index.xml +++ b/index.xml @@ -22,6 +22,13 @@ https://www.jonahramponi.com/posts/flash_attention/ The goal of Flash Attention is to compute the attention value with fewer high bandwidth memory read / writes. The approach has since been refined in Flash Attention 2. We will split the attention inputs $Q,K,V$ into blocks. Each block will be handled separately, and attention will therefore be computed with respect to each block. With the correct scaling, adding the outputs from each block we will give us the same attention value as we would get by computing everything all together. + + Sliding Window Attention + https://www.jonahramponi.com/posts/sliding_window_attention/ + Fri, 22 Mar 2024 00:00:00 +0000 + https://www.jonahramponi.com/posts/sliding_window_attention/ + Sliding Window Attention reduces the number of calculations we are doing when computing self attention. Previously, to compute attention we took our input matrix of positional encodings $M$, and made copies named $Q, K$ and $V$. We used these copies to compute \begin{equation} \text{attention}(Q,K,V) = \text{softmax}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big) V. \end{equation} For now, let’s ignore the re-scaling by $\sqrt{d_k}$ and just look at the computation of $QK^T$. This computation looks like \begin{equation} Q \times K^T = \begin{pmatrix} Q_{11} & Q_{12} & \cdots & Q_{1d} \\ Q_{21} & Q_{22} & \cdots & Q_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ Q_{n1} & Q_{n2} & \cdots & Q_{nd} \end{pmatrix} \times \begin{pmatrix} K_{11} & K_{21} & \cdots & K_{n1} \\ K_{12} & K_{22} & \cdots & K_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ K_{1d} & K_{2d} & \cdots & K_{nd} \end{pmatrix} \end{equation} + Sparse Attention https://www.jonahramponi.com/posts/sparse_attention/ diff --git a/posts/index.html b/posts/index.html index 2e07b6b..c3b2ab0 100644 --- a/posts/index.html +++ b/posts/index.html @@ -81,6 +81,8 @@

All articles

Post 2 Mar 30, 2024
  • Flash Attention Mar 26, 2024 +
  • + Sliding Window Attention Mar 22, 2024
  • Sparse Attention Mar 22, 2024
  • diff --git a/posts/index.xml b/posts/index.xml index c1600e7..c791de9 100644 --- a/posts/index.xml +++ b/posts/index.xml @@ -22,6 +22,13 @@ https://www.jonahramponi.com/posts/flash_attention/ The goal of Flash Attention is to compute the attention value with fewer high bandwidth memory read / writes. The approach has since been refined in Flash Attention 2. We will split the attention inputs $Q,K,V$ into blocks. Each block will be handled separately, and attention will therefore be computed with respect to each block. With the correct scaling, adding the outputs from each block we will give us the same attention value as we would get by computing everything all together. + + Sliding Window Attention + https://www.jonahramponi.com/posts/sliding_window_attention/ + Fri, 22 Mar 2024 00:00:00 +0000 + https://www.jonahramponi.com/posts/sliding_window_attention/ + Sliding Window Attention reduces the number of calculations we are doing when computing self attention. Previously, to compute attention we took our input matrix of positional encodings $M$, and made copies named $Q, K$ and $V$. We used these copies to compute \begin{equation} \text{attention}(Q,K,V) = \text{softmax}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big) V. \end{equation} For now, let’s ignore the re-scaling by $\sqrt{d_k}$ and just look at the computation of $QK^T$. This computation looks like \begin{equation} Q \times K^T = \begin{pmatrix} Q_{11} & Q_{12} & \cdots & Q_{1d} \\ Q_{21} & Q_{22} & \cdots & Q_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ Q_{n1} & Q_{n2} & \cdots & Q_{nd} \end{pmatrix} \times \begin{pmatrix} K_{11} & K_{21} & \cdots & K_{n1} \\ K_{12} & K_{22} & \cdots & K_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ K_{1d} & K_{2d} & \cdots & K_{nd} \end{pmatrix} \end{equation} + Sparse Attention https://www.jonahramponi.com/posts/sparse_attention/ diff --git a/posts/sliding_window_attention/index.html b/posts/sliding_window_attention/index.html new file mode 100644 index 0000000..7f7695c --- /dev/null +++ b/posts/sliding_window_attention/index.html @@ -0,0 +1,169 @@ + + + + Sliding Window Attention - Jonah's ML Notes + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + +
    + +
    +
    +
    +

    Sliding Window Attention

    +
    Posted on Mar 22, 2024
    +
    + +
    + tl;dr: + Altering the tokens to which a token in the input sequence attends. +
    + +
    +

Sliding Window Attention reduces the number of calculations required to compute self-attention. Previously, to compute attention we took our input matrix of positional encodings $M$ and made copies named $Q$, $K$ and $V$. We used these copies to compute

    +

    \begin{equation} +\text{attention}(Q,K,V) = \text{softmax}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big) V. +\end{equation}

    +

    For now, let’s ignore the re-scaling by $\sqrt{d_k}$ and just look at the computation of $QK^T$. This computation looks like +\begin{equation} +Q \times K^T = \begin{pmatrix} +Q_{11} & Q_{12} & \cdots & Q_{1d} \\ +Q_{21} & Q_{22} & \cdots & Q_{2d} \\ +\vdots & \vdots & \ddots & \vdots \\ +Q_{n1} & Q_{n2} & \cdots & Q_{nd} +\end{pmatrix} \times +\begin{pmatrix} +K_{11} & K_{21} & \cdots & K_{n1} \\ +K_{12} & K_{22} & \cdots & K_{n2} \\ +\vdots & \vdots & \ddots & \vdots \\ +K_{1d} & K_{2d} & \cdots & K_{nd} +\end{pmatrix} +\end{equation}
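To make the dense computation concrete, here is a minimal NumPy sketch of the formula above (the variable names and sizes are illustrative; the post itself does not include code):

```python
import numpy as np

def attention(Q, K, V):
    """Dense self-attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (n, n) matrix of dot products
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V                               # (n, d) output

# Toy example: n = 6 tokens, d = 4 dimensions.
n, d = 6, 4
M = np.random.randn(n, d)                  # stand-in for the encoded input
Q, K, V = M.copy(), M.copy(), M.copy()     # copies of M, as in the post
out = attention(Q, K, V)                   # shape (6, 4)
```

Every one of the $n^2$ entries of $QK^T$ is computed here, which is exactly the cost sliding window attention aims to cut.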

    +

Our goal is to simplify this computation. Instead of letting each token attend to all of the other tokens, we define a window size $w$. The token we are calculating attention values for then only gets to look at the $\frac{1}{2}w$ tokens on either side of it. For our example, we could take a sliding window of size $2$, which looks $1$ token to either side of the current token. Only the shaded (olive) values will be calculated.

    +

    Sliding Window Attention Matrix

    +

This greatly reduces the cost of computing $Q \times K^T$: only the entries of the product that fall inside the window are ever calculated, so the computation now looks like

    +

\begin{equation}
Q K^T = \begin{pmatrix}
S_{11} & S_{12} & & \\
S_{21} & S_{22} & S_{23} & \\
 & \ddots & \ddots & \ddots \\
 & & S_{n,n-1} & S_{nn}
\end{pmatrix}, \qquad S_{ij} = \sum_{t=1}^{d} Q_{it} K_{jt},
\end{equation}
where the blank entries are never computed.
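As a rough sketch of how the banded computation might be organised (a plain Python loop for clarity rather than an optimised kernel; each row now needs roughly $w$ dot products instead of $n$):

```python
import numpy as np

def sliding_window_scores(Q, K, w):
    """Compute only the scores within w/2 positions of each token.

    Entries outside the window are set to -inf so they vanish after softmax.
    """
    n, d_k = Q.shape
    half = w // 2
    scores = np.full((n, n), -np.inf)
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)
        scores[i, lo:hi] = Q[i] @ K[lo:hi].T / np.sqrt(d_k)   # ~w dot products per row
    return scores

n, d = 6, 4
Q, K = np.random.randn(n, d), np.random.randn(n, d)
scores = sliding_window_scores(Q, K, w=2)   # tri-diagonal band of computed entries
```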

    +

However, the Longformer authors encountered a problem in training: this approach alone is not flexible enough for learning task-specific representations. They solved the problem by introducing global attention, which gives a few of our tokens some special properties:

    +

• A token with global attention attends to all other tokens in the sequence.
• All tokens in the sequence attend to every token with global attention.

    +

    The local attention (sliding window attention) is primarily used to build contextual representations, while the global attention allows the model to build full sequence representations for prediction.
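One way to picture the combined pattern is as a boolean mask over the $n \times n$ score matrix. The sketch below is my own construction of that mask from the two properties above, not code from the Longformer release:

```python
import numpy as np

def sliding_window_with_global_mask(n, w, global_positions):
    """True where attention is allowed: a band of width w plus full rows and
    columns for every token that has global attention."""
    half = w // 2
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= half   # local sliding window band
    mask[global_positions, :] = True   # a global token attends to all other tokens
    mask[:, global_positions] = True   # all tokens attend to every global token
    return mask

# e.g. one special classification-style token at position 0 with global attention
mask = sliding_window_with_global_mask(n=8, w=2, global_positions=[0])
print(mask.astype(int))
```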

    +

We require two sets of projection matrices: one set $\{Q_s, K_s, V_s\}$ to compute attention scores for the sliding window attention, and a second set $\{Q_g, K_g, V_g\}$ to compute attention scores for the global attention. Both sets are initialized to the same values.

    +

We first calculate the local (sliding window) attention weights using $\{Q_s, K_s, V_s\}$. The global attention weights, computed with $\{Q_g, K_g, V_g\}$, are then written on top of the attention weight matrix produced by the local attention calculation.
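Under my reading of that description, the combination could be sketched as follows (dense matrices for clarity; a real implementation would only compute the entries inside the band, and the projection matrix names are illustrative):

```python
import numpy as np

def combined_scores(x, Wq_s, Wk_s, Wq_g, Wk_g, w, global_positions):
    """Local (sliding window) scores, with global scores written on top of the
    rows and columns belonging to tokens that have global attention."""
    n = x.shape[0]
    d_k = Wq_s.shape[1]
    half = w // 2
    idx = np.arange(n)
    band = np.abs(idx[:, None] - idx[None, :]) <= half

    Qs, Ks = x @ Wq_s, x @ Wk_s              # sliding window projections
    Qg, Kg = x @ Wq_g, x @ Wk_g              # global projections (same initial values)

    local = (Qs @ Ks.T) / np.sqrt(d_k)
    scores = np.where(band, local, -np.inf)  # keep only the banded local scores

    global_scores = (Qg @ Kg.T) / np.sqrt(d_k)
    scores[global_positions, :] = global_scores[global_positions, :]   # global rows
    scores[:, global_positions] = global_scores[:, global_positions]   # global columns
    return scores                            # softmax and multiplication by V follow

n, d_model, d_k = 8, 16, 16
x = np.random.randn(n, d_model)
Wq, Wk = np.random.randn(d_model, d_k), np.random.randn(d_model, d_k)
scores = combined_scores(x, Wq, Wk, Wq.copy(), Wk.copy(), w=2, global_positions=[0])
```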

    +

Dilated Sliding Window Attention is another approach to achieving a similar result. This time, instead of simply taking the $\frac{1}{2}w$ tokens on either side of a given token, we introduce gaps of size $d$, referred to as the dilation. Using $w=2, d=1$ in our example, we would have an attention matrix which looks like

    +

    Dilated Sliding Window Attention Matrix
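The dilated pattern can be written as a mask in the same style. The sketch below is my interpretation of $w$ and $d$ (attend only to offsets that are multiples of $d+1$, up to $\frac{1}{2}w$ such steps on each side), not code from the paper:

```python
import numpy as np

def dilated_window_mask(n, w, d):
    """Sliding window mask with dilation d: gaps of size d between attended positions."""
    half = w // 2
    idx = np.arange(n)
    offset = np.abs(idx[:, None] - idx[None, :])
    within_reach = offset <= half * (d + 1)      # at most w/2 dilated steps away
    on_dilated_step = offset % (d + 1) == 0      # skip the d tokens inside each gap
    return within_reach & on_dilated_step

# w = 2, d = 1: each token attends to itself and the tokens two positions away.
print(dilated_window_mask(n=6, w=2, d=1).astype(int))
```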

    +

The Longformer authors provide a nice visual of how this looks in general, which you can see below. They note that they use sliding window attention with small window sizes in the lower layers and larger window sizes in the higher layers. They do not introduce dilation in the lower layers; in the higher layers, a small amount of increasing dilation is added on just $2$ heads.

    +

    Attention Matrix Visualizations from the Longformer Paper

    + +
    + + +
    +
    + + + +
    + + diff --git a/sitemap.xml b/sitemap.xml index 102d564..9957b8e 100644 --- a/sitemap.xml +++ b/sitemap.xml @@ -22,6 +22,9 @@ https://www.jonahramponi.com/tags/ 2024-03-26T00:00:00+00:00 + + https://www.jonahramponi.com/posts/sliding_window_attention/ + 2024-03-22T00:00:00+00:00 https://www.jonahramponi.com/posts/sparse_attention/ 2024-03-22T00:00:00+00:00 diff --git a/tags/attention/index.html b/tags/attention/index.html index 0557a4a..1bead0d 100644 --- a/tags/attention/index.html +++ b/tags/attention/index.html @@ -79,6 +79,8 @@

    Entries tagged - "attention"

    • Flash Attention Mar 26, 2024 +
    • + Sliding Window Attention Mar 22, 2024
    • Sparse Attention Mar 22, 2024
    • diff --git a/tags/attention/index.xml b/tags/attention/index.xml index 74aef5f..6e210a0 100644 --- a/tags/attention/index.xml +++ b/tags/attention/index.xml @@ -15,6 +15,13 @@ https://www.jonahramponi.com/posts/flash_attention/ The goal of Flash Attention is to compute the attention value with fewer high bandwidth memory read / writes. The approach has since been refined in Flash Attention 2. We will split the attention inputs $Q,K,V$ into blocks. Each block will be handled separately, and attention will therefore be computed with respect to each block. With the correct scaling, adding the outputs from each block we will give us the same attention value as we would get by computing everything all together. + + Sliding Window Attention + https://www.jonahramponi.com/posts/sliding_window_attention/ + Fri, 22 Mar 2024 00:00:00 +0000 + https://www.jonahramponi.com/posts/sliding_window_attention/ + Sliding Window Attention reduces the number of calculations we are doing when computing self attention. Previously, to compute attention we took our input matrix of positional encodings $M$, and made copies named $Q, K$ and $V$. We used these copies to compute \begin{equation} \text{attention}(Q,K,V) = \text{softmax}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big) V. \end{equation} For now, let’s ignore the re-scaling by $\sqrt{d_k}$ and just look at the computation of $QK^T$. This computation looks like \begin{equation} Q \times K^T = \begin{pmatrix} Q_{11} & Q_{12} & \cdots & Q_{1d} \\ Q_{21} & Q_{22} & \cdots & Q_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ Q_{n1} & Q_{n2} & \cdots & Q_{nd} \end{pmatrix} \times \begin{pmatrix} K_{11} & K_{21} & \cdots & K_{n1} \\ K_{12} & K_{22} & \cdots & K_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ K_{1d} & K_{2d} & \cdots & K_{nd} \end{pmatrix} \end{equation} + Sparse Attention https://www.jonahramponi.com/posts/sparse_attention/ diff --git a/tags/inference/index.html b/tags/inference/index.html index 5b6dd78..6fa7a3e 100644 --- a/tags/inference/index.html +++ b/tags/inference/index.html @@ -79,6 +79,8 @@

      Entries tagged - "inference"

      • Flash Attention Mar 26, 2024 +
      • + Sliding Window Attention Mar 22, 2024
      • Sparse Attention Mar 22, 2024
      • diff --git a/tags/inference/index.xml b/tags/inference/index.xml index 90b1075..83c0ba3 100644 --- a/tags/inference/index.xml +++ b/tags/inference/index.xml @@ -15,6 +15,13 @@ https://www.jonahramponi.com/posts/flash_attention/ The goal of Flash Attention is to compute the attention value with fewer high bandwidth memory read / writes. The approach has since been refined in Flash Attention 2. We will split the attention inputs $Q,K,V$ into blocks. Each block will be handled separately, and attention will therefore be computed with respect to each block. With the correct scaling, adding the outputs from each block we will give us the same attention value as we would get by computing everything all together. + + Sliding Window Attention + https://www.jonahramponi.com/posts/sliding_window_attention/ + Fri, 22 Mar 2024 00:00:00 +0000 + https://www.jonahramponi.com/posts/sliding_window_attention/ + Sliding Window Attention reduces the number of calculations we are doing when computing self attention. Previously, to compute attention we took our input matrix of positional encodings $M$, and made copies named $Q, K$ and $V$. We used these copies to compute \begin{equation} \text{attention}(Q,K,V) = \text{softmax}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big) V. \end{equation} For now, let’s ignore the re-scaling by $\sqrt{d_k}$ and just look at the computation of $QK^T$. This computation looks like \begin{equation} Q \times K^T = \begin{pmatrix} Q_{11} & Q_{12} & \cdots & Q_{1d} \\ Q_{21} & Q_{22} & \cdots & Q_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ Q_{n1} & Q_{n2} & \cdots & Q_{nd} \end{pmatrix} \times \begin{pmatrix} K_{11} & K_{21} & \cdots & K_{n1} \\ K_{12} & K_{22} & \cdots & K_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ K_{1d} & K_{2d} & \cdots & K_{nd} \end{pmatrix} \end{equation} + Sparse Attention https://www.jonahramponi.com/posts/sparse_attention/