
Commit

more tidying
jonah-ramponi committed Mar 30, 2024
1 parent 387db44 commit 8d47333
Showing 2 changed files with 7 additions and 26 deletions.
22 changes: 3 additions & 19 deletions content/posts/sliding_window_attention.md
@@ -17,12 +17,12 @@ For now, let's ignore the re-scaling by $\sqrt{d_k}$ and just look at the comput
\begin{equation}
Q \times K^T = \begin{pmatrix}
Q_{11} & Q_{12} & \cdots & Q_{1d} \\\\
Q_{21} & Q_{22} & \cdots & Q_{2d} \\\\
\vdots & \vdots & \ddots & \vdots \\\\
Q_{n1} & Q_{n2} & \cdots & Q_{nd}
\end{pmatrix} \times
\begin{pmatrix}
K_{11} & K_{21} & \cdots & K_{n1} \\\\
K_{12} & K_{22} & \cdots & K_{n2} \\\\
\vdots & \vdots & \ddots & \vdots \\\\
K_{1d} & K_{2d} & \cdots & K_{nd}
\end{pmatrix}
\end{equation}
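
To make the cost concrete, here is a minimal numpy sketch of that dense computation (toy sizes and variable names are illustrative, not taken from any particular implementation); the full $n \times n$ score matrix is exactly what sliding window attention avoids materialising.

```python
import numpy as np

# Toy sizes, purely illustrative.
n, d_k = 6, 4                       # sequence length, key/query dimension
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d_k))   # one query row per token
K = rng.standard_normal((n, d_k))   # one key row per token

scores = Q @ K.T                    # shape (n, n): every token scored against every token
print(scores.shape)                 # (6, 6) -- work and memory grow as O(n^2)
```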
@@ -31,23 +31,7 @@ Our goal is to simplify this computation. Instead of letting each token attend t

![Sliding Window Attention Matrix](/img/sliding_window.png)

This greatly reduces the cost of the computation of $Q \times K^T$, as our computation will now look like

\begin{equation}
Q \times K^T = \begin{pmatrix}
Q_{11} & Q_{12} & &\\\\
Q_{21} & Q_{22} & \cdots & \\\\
& & \cdots & Q_{nd}
\end{pmatrix} \times
\begin{pmatrix}
K_{11} & K_{21} & & \\\\
K_{12} & K_{22} & \cdots & \\\\
& & \cdots & K_{nd}
\end{pmatrix}
\end{equation}


However, the original authors encountered a problem in training. The authors found that this approach is not flexible enough to learn to complete specific tasks. They solved this problem through the introduction of \textit{global attention}. This will give a few of our tokens some special properties: A token with a global attention attends to all other tokens in the sequence and all tokens in the sequence attend to every token with a global attention.
This greatly reduces the cost of computing $Q \times K^T$. However, the original authors encountered a problem in training: this approach is not flexible enough to learn to complete specific tasks. They solved this by introducing *global attention*, which gives a few of our tokens special properties: a token with global attention attends to all other tokens in the sequence, and all tokens in the sequence attend to every token with global attention.

The local attention (sliding window attention) is primarily used to build contextual representations, while the global attention allows the model to build full sequence representations for prediction.
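
As a rough sketch of how the combined local-plus-global pattern could be expressed as a mask (the window convention, the `global_tokens` argument and the function name are my own illustrative choices, not the paper's implementation):

```python
import numpy as np

def sliding_window_mask(n, window, global_tokens=()):
    """Boolean (n, n) mask sketch: True where token i may attend to token j."""
    idx = np.arange(n)
    # local (sliding window) attention: only positions within half a window either side
    mask = np.abs(idx[:, None] - idx[None, :]) <= window // 2
    # global attention: these tokens attend everywhere, and everything attends to them
    for g in global_tokens:
        mask[g, :] = True
        mask[:, g] = True
    return mask

mask = sliding_window_mask(n=8, window=4, global_tokens=(0,))
# In practice the mask is applied to the scores before the softmax, e.g.
# scores = np.where(mask, scores, -np.inf)
```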

11 changes: 4 additions & 7 deletions content/posts/sparse_attention.md
@@ -15,8 +15,6 @@ tags: [attention, inference]

Here, we have defined

$ Q_{S_i} = (W_q x_j)_{j \text{ in } S_i}$

$$ Q_{S_i} = (W_q x_j), K_{S_i} = (W_k x_j), V_{S_i} = (W_v x_j) \text{ for } j \in S_i $$
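
Read concretely, this just says: apply the usual projection matrices to every $x_j$, then keep only the rows indexed by $S_i$. A small numpy sketch for a single position $i$ (all sizes, weights and the particular $S_i$ are made-up illustrations):

```python
import numpy as np

n, d_model, d_k = 8, 16, 4
rng = np.random.default_rng(1)
x = rng.standard_normal((n, d_model))      # token representations x_j
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))

i = 5
S_i = [2, 3, 4, 5]                         # a connectivity pattern with j <= i

q_i = x[i] @ W_q                           # query for position i
K_S = x[S_i] @ W_k                         # K_{S_i}: keys restricted to S_i
V_S = x[S_i] @ W_v                         # V_{S_i}: values restricted to S_i

weights = np.exp(q_i @ K_S.T / np.sqrt(d_k))
weights /= weights.sum()                   # softmax over the |S_i| positions only
out_i = weights @ V_S                      # attention output for position i
```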

So how do we define the set of connectivity patterns $S$? Formally, we let $S_i = A_i^{h}$ for head $h$, where $A_i^{h} \subset \{j : j \leq i\}$. This still leaves open which indices we should take for a given $S_i$. The original authors initially consider two key criteria:
@@ -39,12 +37,11 @@ We now investigate two different approaches that satisfy this criteria, and allo

Here, $A_i^{(1)}$ simply takes the previous $l$ locations. $A_i^{(2)}$ then takes every $l$th earlier location, that is, every position $j$ for which $i-j$ is divisible by $l$ without remainder. This is particularly useful where you can align the structure of your input with the stride, for instance with a piece of music. Where our input does not have a well defined structure, we use something different. In the image below, you can see $A_i^{(1)}$ responsible for the dark blue shading and $A_i^{(2)}$ responsible for the light blue.
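
A tiny sketch of those two strided index sets, under my reading of the pattern (helper names and the inclusion of position $i$ itself are assumptions rather than the authors' exact definition):

```python
# Helper names are mine; i is the query position, l the stride.
def strided_A1(i, l):
    # the previous l locations (plus position i itself)
    return list(range(max(0, i - l), i + 1))

def strided_A2(i, l):
    # every l-th earlier location: positions j with (i - j) divisible by l
    return [j for j in range(i + 1) if (i - j) % l == 0]

print(strided_A1(10, 4))   # [6, 7, 8, 9, 10]
print(strided_A2(10, 4))   # [2, 6, 10]
```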

**Fixed Attention*.* Our goal with this approach is to allow specific cells to summarize the previous locations, and to propagate this information on to future cells.
**Fixed Attention**. Our goal with this approach is to allow specific cells to summarize the previous locations, and to propagate this information on to future cells.

$$ A^{(1)}_i = \Big\{ j : \text{floor}(\frac{j}{l}) = \text{floor}( \frac{i}{l}) \Big\}, $$

\begin{align*}
A^{(1)}_i &= \Big\{ j : \text{floor}(\frac{j}{l}) = \text{floor}( \frac{i}{l}) \Big\}, \\\\
A^{(2)}_i &= \Big\{ j : j \mod l \in \{ t, t + 1, \ldots, l \} \Big\}, \text{ where } t = l - c \text{ and } c \text{ is a hyperparameter.}
\end{align*}
$$ A^{(2)}_i = \Big\{ j : j \mod l \in \{ t, t + 1, \ldots, l \} \Big\}, \text{ where } t = l - c \text{ and } c \text{ is a hyperparameter.} $$

These are best understood visually, in my opinion. In the image below, $A_i^{(1)}$ is responsible for the dark blue shading and $A_i^{(2)}$ for the light blue shading. If we take stride $l = 128$ and $c = 8$, then all positions greater than 128 can attend to positions $120-128$. The authors found that choosing $c \in \{8,16,32\}$ worked well.
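
To tie the numbers together, here is a small sketch of the two fixed sets using the stride and $c$ from the text (function names, zero-based indexing and the causal restriction $j \leq i$ are my assumptions, so the example prints offsets 120 to 127):

```python
# Fixed pattern sketch with the numbers from the text: l = 128, c = 8, so t = l - c = 120.
def fixed_A1(i, l=128):
    # positions in the same length-l block as i (restricted to j <= i)
    return [j for j in range(i + 1) if j // l == i // l]

def fixed_A2(i, l=128, c=8):
    # positions whose offset within their block is at least t = l - c
    t = l - c
    return [j for j in range(i + 1) if j % l >= t]

# Any position beyond the first block can attend (via A^(2)) to the summary positions:
print([j for j in fixed_A2(300) if j < 128])   # [120, 121, 122, 123, 124, 125, 126, 127]
```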

