# Sliding Window Attention (SWA)

This bonus material illustrates the memory savings when using Sliding Window Attention (SWA) over regular Multi-Head Attention (MHA).




## Introduction

What is sliding window attention (SWA)? If we think of regular self-attention as a *global* attention mechanism, since each sequence element can access every other sequence element, then we can think of SWA as *local* attention, because here we restrict the context size around the current query position. This is illustrated in the figure below.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/swa-memory/1.webp?2" alt="Sliding Window Attention" width="500px" />

As shown in the figure above, instead of attending to all previous tokens, each token only attends to a fixed-size local window of the most recent tokens. This localized attention substantially lowers the size of the KV cache.
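
To make the masking pattern concrete, below is a minimal sketch (in PyTorch) of how a causal sliding-window mask could be constructed; the sequence length and window size are arbitrary illustration values rather than settings from any particular model.

```python
import torch

def sliding_window_causal_mask(seq_len, window_size):
    # True = attention allowed; query i may attend to keys j with i - window_size < j <= i
    idx = torch.arange(seq_len)
    dist = idx.unsqueeze(1) - idx.unsqueeze(0)  # dist[i, j] = i - j
    causal = dist >= 0                          # no attending to future tokens
    local = dist < window_size                  # only the most recent `window_size` tokens
    return causal & local

# Example: 6 tokens with a window of 3; each row shows which positions one query can attend to
print(sliding_window_causal_mask(6, 3).int())
```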

In the remainder of this introduction, we will discuss SWA in the context of [Gemma 3](https://arxiv.org/abs/2503.19786), which is implemented from scratch in [../../ch05/12_gemma3](../../ch05/12_gemma3).

Sliding window attention was originally introduced in the [Longformer paper in 2020](https://arxiv.org/abs/2004.05150), but the reason we focus on Google's Gemma models here is that they are strong open-weight models demonstrating that sliding window attention remains a feasible approach in recent, capable LLMs.

[Gemma 2](https://arxiv.org/abs/2408.00118) used a hybrid approach that combined local (sliding window) and global attention layers in a 1:1 ratio, where each token could attend to a context window of 4096 tokens. The reason for this 1:1 hybrid is that it strikes a balance between efficiency and global context modeling, since an LLM using only local attention can be too restrictive.

[Gemma 3](https://arxiv.org/abs/2503.19786) then took the design further toward efficiency. It used a 5:1 ratio between sliding window and full attention layers, which means that for every five local attention layers, there is one global layer. In addition, the sliding window size was reduced from 4096 tokens in Gemma 2 to 1024 tokens in Gemma 3.
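
To picture the 5:1 interleaving, the short snippet below lays out one plausible repeating pattern of local and global layers; this is only an illustration, not the actual layer placement used in the Gemma 3 implementation.

```python
# Illustration only: one plausible way to lay out a repeating 5:1 pattern of
# sliding-window ("swa") and full ("full") attention layers (not the actual Gemma 3 layout).

def layer_types(n_layers, swa_per_full=5):
    pattern = ["swa"] * swa_per_full + ["full"]  # five local layers followed by one global layer
    return [pattern[i % len(pattern)] for i in range(n_layers)]

print(layer_types(12))
# ['swa', 'swa', 'swa', 'swa', 'swa', 'full', 'swa', 'swa', 'swa', 'swa', 'swa', 'full']
```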

Interestingly, the ablation studies in the Gemma 3 technical report indicate that these changes have only a minor effect on overall model quality. In other words, the substantial memory and compute savings achieved through sliding window attention come with minimal loss in modeling performance.




## Sliding Window Attention (SWA) Memory Savings

The memory savings are mostly reflected in the KV cache storage, whose size we can compute with the following formula:

bytes ≈ batch_size × seqlen × (emb_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads

When using SWA, we replace the sequence length (seqlen) above by the window size W. So the KV cache shrinks to a fraction W / seqlen of its original size, i.e., it is reduced by a factor of seqlen / W. (Note that for simplicity, this assumes that sliding window attention is used in every layer.)
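
Plugging in numbers, this is easy to evaluate with a few lines of Python. The helper below is a simplified sketch of the formula (it is not the `memory_estimator_swa.py` script itself) and, as above, assumes SWA is applied in every layer:

```python
# Simplified KV-cache size estimate based on the formula above
# (not the actual memory_estimator_swa.py script); assumes SWA in every layer.

def kv_cache_bytes(batch_size, seqlen, emb_dim, n_heads, n_kv_heads, n_layers, bytes_per_elem=2):
    head_dim = emb_dim // n_heads
    return batch_size * seqlen * head_dim * n_kv_heads * n_layers * 2 * bytes_per_elem

mha = kv_cache_bytes(1, 32_768, 4096, 32, 32, 32)  # full attention: the cache grows with seqlen
swa = kv_cache_bytes(1, 1_024, 4096, 32, 32, 32)   # seqlen replaced by the window size W = 1024
print(f"MHA: {mha / 1e9:.2f} GB, SWA: {swa / 1e9:.2f} GB")  # MHA: 17.18 GB, SWA: 0.54 GB
```

With the settings used in the example command below, the first number matches the 17.18 GB MHA total reported by the script.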


You can use the [memory_estimator_swa.py](memory_estimator_swa.py) script in this folder to apply this formula to different model configs and see how much memory you can save by using SWA over MHA:

```bash
uv run memory_estimator_swa.py \
  --emb_dim 4096 --n_heads 32 --n_layers 32 \
  --context_length 32768 --n_kv_groups 4 \
  --batch_size 1 --dtype bf16 \
  --sliding_window_size 1024 --swa_ratio "5:1"
==== Config ====
context_length : 32768
sliding_window_size : 1024
emb_dim : 4096
n_heads : 32
n_layers : 32
n_kv_groups : 4
batch_size : 1
dtype : bf16 (2 Bytes/elem)
head_dim : 128
GQA n_kv_heads : 8
Effective SWA window W : 1024
Layer ratio (SWA:Full) : 5:1
Distributed layers : 27 SWA, 5 FULL

==== KV-cache totals across all layers ====
MHA KV total : 17.18 GB
GQA KV total : 4.29 GB
MHA + SWA (Ratio: 5:1) : 3.14 GB
GQA + SWA (Ratio: 5:1) : 0.78 GB
```

Note that Gemma 3 uses SWA in combination with GQA.

The savings when using SWA over MHA are further shown in the plot below for different context lengths:



<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/swa-memory/4.webp?2" alt="SWA" width="800px" />



You can reproduce this plot via:

```bash
uv run plot_memory_estimates_swa.py \
  --emb_dim 4096 --n_heads 48 --n_layers 36 \
  --batch_size 1 --dtype bf16 \
  --sliding_window_size 2048 --swa_ratio "5:1"
```



## SWA Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_swa.py](gpt_with_kv_swa.py) scripts in this folder provide hands-on examples for comparing the MHA and SWA memory usage in the context of a GPT model implementation.

Note that SWA can also be used in combination with GQA (as mentioned earlier) or MLA, but for simplicity, this is not done here.

Note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.

Also, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.
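
On the SWA side, the key point is that the per-layer KV cache never needs to grow beyond the window size. As a rough sketch of that idea (this is not the actual code from [gpt_with_kv_swa.py](gpt_with_kv_swa.py), and the tensor layout is an assumption for illustration), the cache update could look like this:

```python
import torch

# Rough sketch of a sliding-window KV-cache update (not the actual gpt_with_kv_swa.py code);
# the (batch, n_heads, seq, head_dim) tensor layout is assumed for illustration.
def update_kv_cache(k_cache, v_cache, k_new, v_new, window_size):
    k_cache = torch.cat([k_cache, k_new], dim=2)[:, :, -window_size:, :]
    v_cache = torch.cat([v_cache, v_new], dim=2)[:, :, -window_size:, :]
    return k_cache, v_cache

# The cache length is capped at `window_size`, no matter how many tokens are generated
k_cache = torch.zeros(1, 4, 0, 64)  # start with an empty cache
v_cache = torch.zeros(1, 4, 0, 64)
for _ in range(2000):
    k_new, v_new = torch.randn(1, 4, 1, 64), torch.randn(1, 4, 1, 64)
    k_cache, v_cache = update_kv_cache(k_cache, v_cache, k_new, v_new, window_size=1024)
print(k_cache.shape)  # torch.Size([1, 4, 1024, 64])
```

The two commands below generate 32,768 tokens with each script and report the runtime, throughput, and peak memory: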

```bash
uv run gpt_with_kv_mha.py \
  --max_new_tokens 32768 \
  --n_heads 24 \
  --n_layers 12 \
  --emb_dim 768

...

Time: 453.81 sec
72 tokens/sec
Max memory allocated: 1.54 GB
```

```bash
uv run gpt_with_kv_swa.py \
  --max_new_tokens 32768 \
  --n_heads 24 \
  --n_layers 12 \
  --emb_dim 768 \
  --sliding_window_size 1024 \
  --sliding_window_stride 5  # like Gemma 3

...

Time: 514.38 sec
63 tokens/sec
Max memory allocated: 0.63 GB
```

The reason we are not seeing savings as large as in the plot above is twofold:

1. I use a smaller model configuration so that the generation finishes in a reasonable amount of time.
2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully connected layers in the model take up most of the memory (but this is a topic for a separate analysis).