
Commit 6eb6adf

sliding window attention (#879)
1 parent 21f0617 commit 6eb6adf

File tree

11 files changed: +1456 -1 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ appendix-E/01_main-chapter-code/loss-plot.pdf
  ch04/04_gqa/kv_bytes_vs_context_length.pdf
  ch05/05_mla/kv_bytes_vs_context_length.pdf
+ ch06/06_swa/kv_bytes_vs_context_length.pdf
  ch05/01_main-chapter-code/loss-plot.pdf
  ch05/01_main-chapter-code/temperature-plot.pdf

README.md

Lines changed: 1 addition & 0 deletions
@@ -170,6 +170,7 @@ Several folders contain optional materials as a bonus for interested readers:
  - [KV Cache](ch04/03_kv-cache)
  - [Grouped-Query Attention](ch04/04_gqa)
  - [Multi-Head Latent Attention](ch04/05_mla)
+ - [Sliding Window Attention](ch04/06_swa)
  - **Chapter 5: Pretraining on unlabeled data:**
  - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
  - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)

ch04/05_mla/README.md

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_mla.py](gpt_with_k
  Here, the MLA code is inspired by the [https://huggingface.co/bird-of-paradise/deepseek-mla](https://huggingface.co/bird-of-paradise/deepseek-mla) implementation.
- Note that MLA can also be used in combination with GQA, but for simplicity, I this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)
+ Note that MLA can also be used in combination with [GQA](../04_gqa), but for simplicity, this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)
  Also note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.
ch04/06_swa/README.md

Lines changed: 132 additions & 0 deletions
# Sliding Window Attention (SWA)

This bonus material illustrates the memory savings when using Sliding Window Attention (SWA) over regular Multi-Head Attention (MHA).

&nbsp;
## Introduction

What is sliding window attention (SWA)? If we think of regular self-attention as a *global* attention mechanism, since each sequence element can access every other sequence element, then we can think of SWA as *local* attention, because here we restrict the context size around the current query position. This is illustrated in the figure below.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/swa-memory/1.webp?2" alt="Sliding Window Attention" width="500px" />

As shown in the figure above, instead of attending to all previous tokens, each token only attends to a fixed-size local window around its position. This localized attention lowers the size of the KV cache substantially.
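To make the local-window idea concrete, below is a minimal sketch of how a sliding-window causal mask can be constructed in PyTorch. This is an illustration only; the function and parameter names are made up for this example and are not taken from the scripts in this folder.

```python
import torch

def build_swa_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # True marks key positions that a given query position may attend to.
    # Regular causal attention only requires key position j <= query position i;
    # sliding window attention additionally requires i - j < window_size,
    # so each query sees at most the last `window_size` tokens (itself included).
    i = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (seq_len, 1)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions,   shape (1, seq_len)
    return (j <= i) & ((i - j) < window_size)

print(build_swa_mask(seq_len=6, window_size=3).int())
# Each row (query) contains at most 3 ones instead of attending to all previous tokens.
```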
In the remainder of this introduction, we will discuss SWA in the context of [Gemma 3](https://arxiv.org/abs/2503.19786), which is implemented from scratch in [../../ch05/12_gemma3](../../ch05/12_gemma3).

Sliding window attention was originally introduced in the [Longformer paper in 2020](https://arxiv.org/abs/2004.05150), but the reason we focus on Google's Gemma models is that they are very good open-weight models showing that sliding window attention is indeed a feasible approach in recent, capable models.

[Gemma 2](https://arxiv.org/abs/2408.00118) used a hybrid approach that combined local (sliding window) and global attention layers in a 1:1 ratio. Each token could attend to a context window of 4096 tokens. The reason for this 1:1 hybrid is that it strikes a balance between efficiency and global context modeling, since an LLM using only local attention can be too restrictive.

[Gemma 3](https://arxiv.org/abs/2503.19786) then took the design further toward efficiency. It used a 5:1 ratio between sliding window and full attention layers, which means that for every five local attention layers, there is one global layer. In addition, the sliding window size was reduced from 4096 tokens in Gemma 2 to 1024 tokens in Gemma 3.
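As a rough sketch of what such a 5:1 layout can look like (the helper name below is made up for illustration and is not the Gemma 3 code), one can interleave the layer types like this:

```python
def swa_full_schedule(n_layers: int, ratio: int = 5) -> list[str]:
    # For every `ratio` sliding-window layers, insert one full-attention layer.
    return [
        "full" if (layer_idx + 1) % (ratio + 1) == 0 else "swa"
        for layer_idx in range(n_layers)
    ]

print(swa_full_schedule(12))
# ['swa', 'swa', 'swa', 'swa', 'swa', 'full', 'swa', 'swa', 'swa', 'swa', 'swa', 'full']
```

For 32 layers, this particular scheme yields 27 sliding-window layers and 5 full-attention layers, which matches the 27 SWA / 5 FULL distribution reported by the memory estimator further below (though the exact placement of the full-attention layers may differ).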
Interestingly, the ablation studies in the Gemma 3 technical report indicate that these changes have only a minor effect on overall model quality. In other words, the substantial memory and compute savings achieved through sliding window attention come with minimal loss in modeling performance.

&nbsp;
## Sliding Window Attention (SWA) Memory Savings

The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula:

bytes ≈ batch_size × seqlen × (emb_dim / n_heads) × n_layers × 2 (K, V) × bytes_per_elem × n_kv_heads

When using SWA, we replace the sequence length (seqlen) above with the window size W. So, when using sliding window attention, we shrink the KV cache to a fraction W / seqlen of its original size. (Note that for simplicity, this assumes that sliding window attention is used in every layer.)
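The formula above translates directly into a few lines of Python. The following is a small, self-contained sketch (the function name and example values are chosen for illustration, and it assumes, as the simplification above does, that SWA is used in every layer; the script used in the next step additionally accounts for the SWA:full layer ratio and GQA):

```python
def kv_cache_bytes(seqlen, emb_dim, n_heads, n_kv_heads, n_layers,
                   batch_size=1, bytes_per_elem=2):
    # bytes ≈ batch_size × seqlen × (emb_dim / n_heads) × n_layers
    #         × 2 (K, V) × bytes_per_elem × n_kv_heads
    head_dim = emb_dim // n_heads
    return batch_size * seqlen * head_dim * n_layers * 2 * bytes_per_elem * n_kv_heads

# Example: 32 layers, 32 heads (no GQA), 32,768-token context, bf16 (2 bytes/elem)
mha = kv_cache_bytes(seqlen=32_768, emb_dim=4096, n_heads=32, n_kv_heads=32, n_layers=32)
swa = kv_cache_bytes(seqlen=1_024, emb_dim=4096, n_heads=32, n_kv_heads=32, n_layers=32)

print(f"MHA KV cache: {mha / 1e9:.2f} GB")  # 17.18 GB
print(f"SWA KV cache: {swa / 1e9:.2f} GB")  # 0.54 GB, i.e., W / seqlen = 1/32 of the above
```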
You can use the [memory_estimator_swa.py](memory_estimator_swa.py) script in this folder to apply this formula to different model configs and see how much memory you can save by using SWA over MHA:
```bash
➜ uv run memory_estimator_swa.py \
  --emb_dim 4096 --n_heads 32 --n_layers 32 \
  --context_length 32768 --n_kv_groups 4 \
  --batch_size 1 --dtype bf16 \
  --sliding_window_size 1024 --swa_ratio "5:1"
==== Config ====
context_length          : 32768
sliding_window_size     : 1024
emb_dim                 : 4096
n_heads                 : 32
n_layers                : 32
n_kv_groups             : 4
batch_size              : 1
dtype                   : bf16 (2 Bytes/elem)
head_dim                : 128
GQA n_kv_heads          : 8
Effective SWA window W  : 1024
Layer ratio (SWA:Full)  : 5:1
Distributed layers      : 27 SWA, 5 FULL

==== KV-cache totals across all layers ====
MHA KV total            : 17.18 GB
GQA KV total            : 4.29 GB
MHA + SWA (Ratio: 5:1)  : 3.14 GB
GQA + SWA (Ratio: 5:1)  : 0.78 GB
```
Note that Gemma 3 uses SWA in combination with GQA.

The savings when using SWA over MHA are further shown in the plot below for different context lengths:

&nbsp;

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/swa-memory/4.webp?2" alt="SWA" width="800px" />

&nbsp;
You can reproduce these plots via:

```bash
uv run plot_memory_estimates_swa.py \
  --emb_dim 4096 --n_heads 48 --n_layers 36 \
  --batch_size 1 --dtype bf16 \
  --sliding_window_size 2048 --swa_ratio "5:1"
```
&nbsp;
## SWA Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_swa.py](gpt_with_kv_swa.py) scripts in this folder provide hands-on examples for comparing the MHA and SWA memory usage in the context of a GPT model implementation.

Note that SWA can also be used in combination with MLA and GQA (as mentioned earlier), but for simplicity, this is not done here.

Note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.

Also, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.
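For intuition on why SWA bounds the KV-cache size during generation, here is a minimal sketch of a rolling (fixed-size) KV cache. It is an illustration only, assuming PyTorch as in the rest of the repository; the class and method names are invented for this example and do not appear in the scripts above.

```python
import torch

class RollingKVCache:
    """Keeps only the most recent `window_size` key/value entries."""

    def __init__(self, window_size: int):
        self.window_size = window_size
        self.keys = None    # shape: (batch, n_kv_heads, <=window_size, head_dim)
        self.values = None

    def update(self, new_k: torch.Tensor, new_v: torch.Tensor):
        if self.keys is None:
            self.keys, self.values = new_k, new_v
        else:
            self.keys = torch.cat([self.keys, new_k], dim=2)
            self.values = torch.cat([self.values, new_v], dim=2)
        # Drop entries outside the sliding window so the cache never grows
        # beyond `window_size` positions, no matter how long the generation runs.
        self.keys = self.keys[:, :, -self.window_size:, :]
        self.values = self.values[:, :, -self.window_size:, :]
        return self.keys, self.values

# Toy usage: decode 10 tokens one at a time with a window of 4
cache = RollingKVCache(window_size=4)
for _ in range(10):
    new_k = torch.randn(1, 8, 1, 64)  # (batch, n_kv_heads, 1 new token, head_dim)
    new_v = torch.randn(1, 8, 1, 64)
    keys, values = cache.update(new_k, new_v)
print(keys.shape)  # torch.Size([1, 8, 4, 64]) -- bounded by the window size
```

In contrast, a regular KV cache keeps all 32,768 positions when generating 32,768 tokens, which is where the memory difference in the measurements below comes from.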
```bash
uv run gpt_with_kv_mha.py \
  --max_new_tokens 32768 \
  --n_heads 24 \
  --n_layers 12 \
  --emb_dim 768

...

Time: 453.81 sec
72 tokens/sec
Max memory allocated: 1.54 GB
```
```bash
uv run gpt_with_kv_swa.py \
  --max_new_tokens 32768 \
  --n_heads 24 \
  --n_layers 12 \
  --emb_dim 768 \
  --sliding_window_size 1024 \
  --sliding_window_stride 5  # like Gemma 3

...

Time: 514.38 sec
63 tokens/sec
Max memory allocated: 0.63 GB
```
The reason why we are not seeing as big a saving as in the plots above is twofold:

1. I use a smaller configuration so that the model finishes the generation in a reasonable time.
2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully connected layers in the model take up most of the memory (but this is a topic for a separate analysis).
