diff --git a/.gitignore b/.gitignore
index d83f522fb..ef8f7fec5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,7 @@ appendix-E/01_main-chapter-code/loss-plot.pdf
ch04/04_gqa/kv_bytes_vs_context_length.pdf
ch05/05_mla/kv_bytes_vs_context_length.pdf
+ch04/06_swa/kv_bytes_vs_context_length.pdf
ch05/01_main-chapter-code/loss-plot.pdf
ch05/01_main-chapter-code/temperature-plot.pdf
diff --git a/README.md b/README.md
index bea937190..7805a3426 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [KV Cache](ch04/03_kv-cache)
- [Grouped-Query Attention](ch04/04_gqa)
- [Multi-Head Latent Attention](ch04/05_mla)
+ - [Sliding Window Attention](ch04/06_swa)
- **Chapter 5: Pretraining on unlabeled data:**
- [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
- [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)
diff --git a/ch04/05_mla/README.md b/ch04/05_mla/README.md
index 25d0e19d0..51deb2456 100644
--- a/ch04/05_mla/README.md
+++ b/ch04/05_mla/README.md
@@ -101,7 +101,7 @@ The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_mla.py](gpt_with_k
Here, the MLA code is inspired by the [https://huggingface.co/bird-of-paradise/deepseek-mla](https://huggingface.co/bird-of-paradise/deepseek-mla) implementation.
-Note that MLA can also be used in combination with GQA, but for simplicity, I this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)
+Note that MLA can also be used in combination with [GQA](../04_gqa), but for simplicity, this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)
Also note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.
diff --git a/ch04/06_swa/README.md b/ch04/06_swa/README.md
new file mode 100644
index 000000000..0520e7546
--- /dev/null
+++ b/ch04/06_swa/README.md
@@ -0,0 +1,132 @@
+# Sliding Window Attention (SWA)
+
+This bonus material illustrates the memory savings when using Sliding Window Attention (SWA) over regular Multi-Head Attention (MHA).
+
+
+&nbsp;
+## Introduction
+
+What is sliding window attention (SWA)? If we think of regular self-attention as a *global* attention mechanism, since each query can attend to all previous sequence positions, then we can think of SWA as *local* attention, because it restricts attention to a fixed-size window around the current query position. This is illustrated in the figure below.
+
+Sliding Window Attention
+
+As shown in the figure above, instead of attending to all previous tokens, each token only attends to a fixed-size local window around its position. This localized attention lowers the size of the KV cache substantially.
+
+In the remainder of this introduction, we will discuss SWA in the context of [Gemma 3](https://arxiv.org/abs/2503.19786), which is implemented from scratch in [../../ch05/12_gemma3](../../ch05/12_gemma3).
+
+Sliding window attention was originally introduced in the [Longformer paper in 2020](https://arxiv.org/abs/2004.05150), but we focus on Google's Gemma models here because they are capable open-weight models that show sliding window attention is a feasible approach in recent LLMs.
+
+[Gemma 2](https://arxiv.org/abs/2408.00118) used a hybrid approach that combined local (sliding window) and global attention layers in a 1:1 ratio.
+Each token could attend to a context window of 4096 tokens. The reason for this 1:1 hybrid is that it strikes a balance between efficiency and global context modeling, since an LLM using only local attention can be too restrictive.
+
+[Gemma 3](https://arxiv.org/abs/2503.19786) then took the design further toward efficiency. It used a 5:1 ratio between sliding window and full attention layers, which means that for every five local attention layers, there is one global layer. In addition, the sliding window size was reduced from 4096 tokens in Gemma 2 to 1024 tokens in Gemma 3.
+
+Interestingly, the ablation studies in the Gemma 3 technical report indicate that these changes have only a minor effect on overall model quality. In other words, the substantial memory and compute savings achieved through sliding window attention come with minimal loss in modeling performance.
+
+
+&nbsp;
+## Sliding Window Attention (SWA) Memory Savings
+
+The memory savings are mostly reflected in the KV cache. We can estimate the KV cache size with the following formula:
+
+bytes ≈ batch_size × seqlen × (emb_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads
+
+For example, for the 32-layer config used below (emb_dim=4096 with 32 heads, so head_dim=128, bf16) at a context length of 32,768 tokens, this comes out to 1 × 32,768 × 128 × 32 × 2 × 2 × 32 bytes ≈ 17.18 GB, which matches the MHA line in the output below.
+
+When using SWA, we replace the sequence length (seqlen) above with the window size W. In other words, sliding window attention shrinks the KV cache by a factor of seqlen / W, that is, to a fraction W / seqlen of its original size. (Note that for simplicity, this assumes that sliding window attention is used in every layer.)
+
+You can use the [memory_estimator_swa.py](memory_estimator_swa.py) script in this folder to apply this formula to different model configs and see how much memory you can save by using SWA over MHA:
+
+```bash
+➜ uv run memory_estimator_swa.py \
+  --emb_dim 4096 --n_heads 32 --n_layers 32 \
+  --context_length 32768 --n_kv_groups 4 \
+  --batch_size 1 --dtype bf16 \
+  --sliding_window_size 1024 --swa_ratio "5:1"
+==== Config ====
+context_length : 32768
+sliding_window_size : 1024
+emb_dim : 4096
+n_heads : 32
+n_layers : 32
+n_kv_groups : 4
+batch_size : 1
+dtype : bf16 (2 Bytes/elem)
+head_dim : 128
+GQA n_kv_heads : 8
+Effective SWA window W : 1024
+Layer ratio (SWA:Full) : 5:1
+Distributed layers : 27 SWA, 5 FULL
+
+==== KV-cache totals across all layers ====
+MHA KV total : 17.18 GB
+GQA KV total : 4.29 GB
+MHA + SWA (Ratio: 5:1) : 3.14 GB
+GQA + SWA (Ratio: 5:1) : 0.78 GB
+```
+
+Note that Gemma 3 uses SWA in combination with GQA.
+
+The savings when using SWA over MHA are further shown in the plot below for different context lengths:
+
+&nbsp;
+
+SWA
+
+&nbsp;
+
+You can reproduce these plots via:
+
+```bash
+uv run plot_memory_estimates_swa.py \
+  --emb_dim 4096 --n_heads 48 --n_layers 36 \
+  --batch_size 1 --dtype bf16 \
+  --sliding_window_size 2048 --swa_ratio "5:1"
+```
+
+
+&nbsp;
+## SWA Code Examples
+
+The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_swa.py](gpt_with_kv_swa.py) scripts in this folder provide hands-on examples for comparing the MHA and SWA memory usage in the context of a GPT model implementation.
+
+Note that SWA can also be used in combination with MLA and GQA (as mentioned earlier), but for simplicity, this is not done here.
+
+Note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.
+
+Also, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.
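+To make the difference between the two scripts concrete before running them, here is a minimal, self-contained sketch (with made-up tensor sizes, not the exact code from the scripts) of the left-trimming idea that keeps the SWA KV cache bounded; [gpt_with_kv_swa.py](gpt_with_kv_swa.py) applies the same idea per attention layer and additionally adjusts the causal mask so that queries only attend to keys inside the window:
+
+```python
+import torch
+
+window_size = 4  # illustration only; Gemma 3 uses 1024
+cache_k = None   # the V cache would be handled the same way
+
+for step in range(10):  # pretend we decode 10 tokens, one at a time
+    new_k = torch.randn(1, 1, 2, 8)  # (batch, 1 new token, num_heads, head_dim)
+    cache_k = new_k if cache_k is None else torch.cat([cache_k, new_k], dim=1)
+    # Left-trim so the cache never holds more than `window_size` tokens
+    cache_k = cache_k[:, -window_size:, :, :]
+
+print(cache_k.shape)  # torch.Size([1, 4, 2, 8]) -> bounded by the window
+```
+
+Without the left-trim line, the cache would keep growing with the generated sequence length, which is where the KV-cache savings of the SWA script in the example runs below come from.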
+ +```bash +uv run gpt_with_kv_mha.py \ +--max_new_tokens 32768 \ +--n_heads 24 \ +--n_layers 12 \ +--emb_dim 768 + +... + +Time: 453.81 sec +72 tokens/sec +Max memory allocated: 1.54 GB +``` + +```bash +uv run gpt_with_kv_swa.py \ +--max_new_tokens 32768 \ +--n_heads 24 \ +--n_layers 12 \ +--emb_dim 768 \ +--sliding_window_size 1024 \ +--sliding_window_stride 5 # like Gemma 3 + +... + +Time: 514.38 sec +63 tokens/sec +Max memory allocated: 0.63 GB +``` + +The reason why we are not seeing such a big saving as in the plots above is 2-fold: + +1. I use a smaller configuration to have the model finish the generation in a reasonable time. +2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully-connected layers in the model take up most of the memory (but this is a topic for a separate analysis). diff --git a/ch04/06_swa/gpt_with_kv_mha.py b/ch04/06_swa/gpt_with_kv_mha.py new file mode 100644 index 000000000..f906d71b2 --- /dev/null +++ b/ch04/06_swa/gpt_with_kv_mha.py @@ -0,0 +1,344 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4. +# This file can be run as a standalone script. + +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + + #################################################### + # KV cache-related code + self.register_buffer("cache_k", None, persistent=False) + self.register_buffer("cache_v", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def forward(self, x, use_cache=False): + b, num_tokens, d_in = x.shape + + keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) + values_new = self.W_value(x) + queries = self.W_query(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) + values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + #################################################### + # KV cache-related + if use_cache: + if self.cache_k is None: + self.cache_k, self.cache_v = keys_new, values_new + else: + self.cache_k = torch.cat([self.cache_k, keys_new], dim=1) + self.cache_v = torch.cat([self.cache_v, values_new], dim=1) + keys, values = self.cache_k, self.cache_v + else: + keys, values = keys_new, values_new + 
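+        # Note: with use_cache=True, cache_k and cache_v grow by one entry per
+        # generated token and are never trimmed here; bounding this growth is
+        # exactly what the sliding-window variant in gpt_with_kv_swa.py adds.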
#################################################### + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + #################################################### + # causal mask + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + if use_cache: + q_positions = torch.arange( + self.ptr_current_pos, + self.ptr_current_pos + num_tokens_Q, + device=device, + dtype=torch.long, + ) + self.ptr_current_pos += num_tokens_Q + else: + q_positions = torch.arange(num_tokens_Q, device=device, dtype=torch.long) + self.ptr_current_pos = 0 + k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long) + mask_bool = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0) + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + def reset_cache(self): + self.cache_k, self.cache_v = None, None + self.ptr_current_pos = 0 + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"]) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward 
block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + #################################################### + # KV cache-related + self.trf_blocks = nn.ModuleList( + [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + logits = model(idx[:, -ctx_len:], use_cache=True) + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) + + return idx + + +def main(): + parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.") + parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + + args = parser.parse_args() + + start_context = "Hello, I 
am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/06_swa/gpt_with_kv_swa.py b/ch04/06_swa/gpt_with_kv_swa.py new file mode 100644 index 000000000..bd4cda744 --- /dev/null +++ b/ch04/06_swa/gpt_with_kv_swa.py @@ -0,0 +1,381 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# This file collects all the relevant code that we covered thus far +# throughout Chapters 3-4. +# This file can be run as a standalone script. 
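+#
+# Compared to gpt_with_kv_mha.py, this script adds sliding window attention (SWA):
+#   1) the causal mask additionally hides keys that are older than `sliding_window_size`,
+#   2) the KV cache is left-trimmed so it never stores more than `sliding_window_size` tokens,
+#   3) a K:1 layer schedule (set via --sliding_window_stride) mixes SWA and full-attention layers.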
+ +import argparse +import time +import tiktoken +import torch +import torch.nn as nn + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttentionWithSWA(nn.Module): + def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False, sliding_window_size=None): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by num_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + self.sliding_window_size = sliding_window_size + + #################################################### + # KV cache-related code + self.register_buffer("cache_k", None, persistent=False) + self.register_buffer("cache_v", None, persistent=False) + self.ptr_current_pos = 0 + #################################################### + + def forward(self, x, use_cache=False): + b, num_tokens, d_in = x.shape + + keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) + values_new = self.W_value(x) + queries = self.W_query(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) + values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + #################################################### + # KV cache-related + if use_cache: + old_len = 0 if self.cache_k is None else self.cache_k.size(1) + if self.cache_k is None: + self.cache_k, self.cache_v = keys_new, values_new + else: + self.cache_k = torch.cat([self.cache_k, keys_new], dim=1) + self.cache_v = torch.cat([self.cache_v, values_new], dim=1) + # Left-trim to sliding window if configured + if self.sliding_window_size is not None: + if self.cache_k.size(1) > self.sliding_window_size: + self.cache_k = self.cache_k[:, -self.sliding_window_size:, :, :] + self.cache_v = self.cache_v[:, -self.sliding_window_size:, :, :] + # Compute absolute start positions for mask + total_len = old_len + num_tokens + k_len_now = self.cache_k.size(1) + dropped = max(0, total_len - k_len_now) + k_start_pos_abs = (self.ptr_current_pos - old_len) + dropped + q_start_pos_abs = self.ptr_current_pos + keys, values = self.cache_k, self.cache_v + else: + keys, values = keys_new, values_new + #################################################### + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + #################################################### + # causal + sliding-window mask + num_tokens_Q = queries.shape[-2] + num_tokens_K = keys.shape[-2] + device = queries.device + # Determine absolute positions for q and k + if use_cache: + q_start = q_start_pos_abs + k_start = k_start_pos_abs + else: + q_start = 0 + k_start = 0 + q_positions = torch.arange(q_start, q_start + num_tokens_Q, 
device=device, dtype=torch.long) + k_positions = torch.arange(k_start, k_start + num_tokens_K, device=device, dtype=torch.long) + # Sliding window width + W = num_tokens_K + 1 if self.sliding_window_size is None else int(self.sliding_window_size) + diff = q_positions.unsqueeze(-1) - k_positions.unsqueeze(0) + mask_bool = (diff < 0) | (diff >= W) + if use_cache: + self.ptr_current_pos += num_tokens_Q + else: + self.ptr_current_pos = 0 + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + def reset_cache(self): + self.cache_k, self.cache_v = None, None + self.ptr_current_pos = 0 + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttentionWithSWA( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"], + sliding_window_size=cfg["sliding_window_size"], + ) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x, use_cache=False): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + + # x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + #################################################### + # KV cache-related + x = self.att(x, use_cache=use_cache) + #################################################### + + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + # self.trf_blocks = nn.Sequential( + # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + 
#################################################### + # KV cache-related + blocks = [] + window_stride = cfg["sliding_window_stride"] + window_size = cfg["sliding_window_size"] if "sliding_window_size" in cfg else None + for i in range(cfg["n_layers"]): + blk = TransformerBlock(cfg) + # K:1 schedule meaning that K SWA layers are followed by 1 regular layer + K = int(window_stride) + if K <= 0: + # 0 => all regular; negative => all SWA + use_swa = False if K == 0 else True + else: + group = K + 1 + use_swa = (i % group) < K + blk.att.sliding_window_size = window_size if use_swa else None + blocks.append(blk) + self.trf_blocks = nn.ModuleList(blocks) + + self.current_pos = 0 + #################################################### + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx, use_cache=False): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + + # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + + #################################################### + # KV cache-related + if use_cache: + pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long) + self.current_pos += seq_len + else: + pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) + pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) + #################################################### + + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + + # x = self.trf_blocks(x) + #################################################### + # KV cache-related + for blk in self.trf_blocks: + x = blk(x, use_cache=use_cache) + #################################################### + + x = self.final_norm(x) + logits = self.out_head(x) + return logits + + #################################################### + # KV cache-related + def reset_kv_cache(self): + for blk in self.trf_blocks: + blk.att.reset_cache() + self.current_pos = 0 + #################################################### + + +def generate_text_simple_cached(model, idx, max_new_tokens, + context_size=None, use_cache=True): + model.eval() + ctx_len = context_size or model.pos_emb.num_embeddings + + with torch.no_grad(): + if use_cache: + # Init cache with full prompt + model.reset_kv_cache() + logits = model(idx[:, -ctx_len:], use_cache=True) + + for _ in range(max_new_tokens): + # a) pick the token with the highest log-probability (greedy sampling) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + # b) append it to the running sequence + idx = torch.cat([idx, next_idx], dim=1) + # c) feed model only the new token + logits = model(next_idx, use_cache=True) + else: + for _ in range(max_new_tokens): + logits = model(idx[:, -ctx_len:], use_cache=False) + next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) + idx = torch.cat([idx, next_idx], dim=1) + + return idx + + +def main(): + parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.") + parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.") + parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.") + parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.") + parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.") + parser.add_argument("--sliding_window_size", type=int, default=1024, 
help="Window size for sliding window attention.") + parser.add_argument("--sliding_window_stride", type=int, default=2, help="K:1 frequency sliding window attention is applied. K=5 means 5 sliding window layers follows by a regular layer.") + + args = parser.parse_args() + + start_context = "Hello, I am" + tokenizer = tiktoken.get_encoding("gpt2") + encoded = tokenizer.encode(start_context) + + GPT_CONFIG_124M = { + "vocab_size": 50257, # Vocabulary size + "context_length": args.max_new_tokens + len(encoded), + "emb_dim": args.emb_dim, # Embedding dimension + "n_heads": args.n_heads, # Number of attention heads + "n_layers": args.n_layers, # Number of layers + "drop_rate": 0.0, # Dropout rate + "qkv_bias": False, # Query-Key-Value bias + "sliding_window_size": args.sliding_window_size, + "sliding_window_stride": args.sliding_window_stride + } + torch.manual_seed(123) + model = GPTModel(GPT_CONFIG_124M) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device, dtype=torch.bfloat16) + model.eval() # disable dropout + + encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0) + print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}") + print("\nInput text:", start_context) + print("Encoded input text:", encoded) + print("encoded_tensor.shape:", encoded_tensor.shape) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + token_ids = generate_text_simple_cached( + model=model, + idx=encoded_tensor, + max_new_tokens=args.max_new_tokens, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_time = time.time() - start + + decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist()) + + print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}") + print("\nOutput:", token_ids) + print("Output length:", len(token_ids[0])) + print("Output text:", decoded_text) + + print(f"\nTime: {total_time:.2f} sec") + print(f"{int(len(token_ids[0])/total_time)} tokens/sec") + if torch.cuda.is_available(): + max_mem_bytes = torch.cuda.max_memory_allocated() + max_mem_gb = max_mem_bytes / (1024 ** 3) + print(f"Max memory allocated: {max_mem_gb:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/ch04/06_swa/memory_estimator_mla.py b/ch04/06_swa/memory_estimator_mla.py new file mode 100644 index 000000000..f9ab9f512 --- /dev/null +++ b/ch04/06_swa/memory_estimator_mla.py @@ -0,0 +1,123 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
+# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# KV-cache memory estimator for MHA vs GQA vs MLA + +import argparse +import math + +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:,.2f} GB" + + +def kv_bytes_total(batch, context_length, emb_dim, n_heads, + n_kv_heads, n_layers, bytes_per_elem): + # Generic KV-cache: per-head dim is embed_dim / n_heads, times 2 for K and V + head_dim = math.ceil(emb_dim / n_heads) + per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem + return per_layer * n_layers + + +def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem): + # Simple MLA (per-token compressed latent) + # bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem + return batch * context_length * n_layers * latent_dim * bytes_per_elem + + +def main(): + p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA vs MLA") + p.add_argument("--context_length", default=1024, type=int) + p.add_argument("--emb_dim", required=True, type=int) + p.add_argument("--n_heads", required=True, type=int) + p.add_argument("--n_layers", required=True, type=int) + p.add_argument("--n_kv_groups", required=True, type=int) + p.add_argument("--latent_dim", required=True, type=int, help="MLA per-token latent dimension") + p.add_argument("--batch_size", default=1, type=int) + p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="fp16") + args = p.parse_args() + + cfg = { + "context_length": args.context_length, + "emb_dim": args.emb_dim, + "n_heads": args.n_heads, + "n_layers": args.n_layers, + "n_kv_groups": args.n_kv_groups, + "latent_dim": args.latent_dim, + } + + if cfg["n_heads"] % cfg["n_kv_groups"] != 0: + raise ValueError("n_kv_groups must divide n_heads exactly.") + + bytes_per_elem = DTYPE_BYTES[args.dtype] + head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"]) + + n_kv_heads_mha = cfg["n_heads"] + n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"] + + total_mha = kv_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["emb_dim"], + cfg["n_heads"], + n_kv_heads_mha, + cfg["n_layers"], + bytes_per_elem, + ) + + total_gqa = kv_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["emb_dim"], + cfg["n_heads"], + n_kv_heads_gqa, + cfg["n_layers"], + bytes_per_elem, + ) + + total_mla = mla_bytes_total( + args.batch_size, + cfg["context_length"], + cfg["n_layers"], + cfg["latent_dim"], + bytes_per_elem, + ) + + ratio = total_mha / total_gqa if total_gqa != 0 else float("inf") + savings = 1 - (total_gqa / total_mha) if total_mha != 0 else 0.0 + + ratio_mha_mla = total_mha / total_mla if total_mla != 0 else float("inf") + savings_mla = 1 - (total_mla / total_mha) if total_mha != 0 else 0.0 + + print("==== Config ====") + for k, v in cfg.items(): + print(f"{k:17}: {v}") + print(f"batch_size : {args.batch_size}") + print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)") + print(f"head_dim : {head_dim}") + print(f"GQA n_kv_heads : {n_kv_heads_gqa}") + print() + + print("==== KV-cache totals across all layers ====") + print(f"MHA total KV cache : {bytes_convert(total_mha)}") + print(f"GQA total KV cache : {bytes_convert(total_gqa)}") + print(f"MLA total KV cache : {bytes_convert(total_mla)}") + print(f"Ratio (MHA / GQA) : {ratio:,.2f}x") + print(f"Savings (GQA vs 
MHA): {savings*100:,.2f}%") + print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x") + print(f"Savings (MLA vs MHA): {savings_mla*100:,.2f}%") + + +if __name__ == "__main__": + main() diff --git a/ch04/06_swa/memory_estimator_swa.py b/ch04/06_swa/memory_estimator_swa.py new file mode 100644 index 000000000..2686c286f --- /dev/null +++ b/ch04/06_swa/memory_estimator_swa.py @@ -0,0 +1,151 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# KV-cache memory estimator for MHA vs GQA with SWA. + +import argparse +import math + +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_convert(n): + gb = n / (1000 ** 3) + return f"{gb:,.2f} GB" + + +def kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem): + # KV = batch * tokens * head_dim * n_kv_heads * 2 (K,V) * bytes + return batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem + + +def parse_ratio(ratio_str): + # "--swa_ratio a:b" means a SWA layers for every b full layers within a block + try: + a_str, b_str = ratio_str.split(":") + a, b = int(a_str), int(b_str) + assert a >= 0 and b >= 0 and (a + b) > 0 + return a, b + except Exception: + raise ValueError("--swa_ratio must be in the form 'a:b' with nonnegative integers and a+b>0") + + +def distribute_layers(n_layers, a, b): + block = a + b + blocks = n_layers // block + rem = n_layers % block + swa = blocks * a + min(a, rem) + full = blocks * b + max(0, rem - a) + return swa, full + + +def estimate_totals(context_length, sliding_window_size, emb_dim, n_heads, n_layers, + n_kv_groups, batch_size, dtype, swa_ratio): + if n_heads % n_kv_groups != 0: + raise ValueError("n_kv_groups must divide n_heads exactly.") + + bytes_per_elem = DTYPE_BYTES[dtype] + head_dim = math.ceil(emb_dim / n_heads) + n_kv_heads_mha = n_heads + n_kv_heads_gqa = n_heads // n_kv_groups + + a_swa, b_full = parse_ratio(swa_ratio) + n_swa_layers, n_full_layers = distribute_layers(n_layers, a_swa, b_full) + + eff_W = min(context_length, sliding_window_size) + L = context_length + + # Per-layer costs + per_mha_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem) + per_gqa_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem) + per_mha_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem) + per_gqa_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem) + + # Totals + total_mha_allfull = per_mha_full * n_layers + total_gqa_allfull = per_gqa_full * n_layers + total_mixed_mha = n_swa_layers * per_mha_swa + n_full_layers * per_mha_full + total_mixed_gqa = n_swa_layers * per_gqa_swa + n_full_layers * per_gqa_full + + return { + "bytes_per_elem": bytes_per_elem, + "head_dim": head_dim, + "n_kv_heads_gqa": n_kv_heads_gqa, + "eff_W": eff_W, + "n_swa_layers": n_swa_layers, + "n_full_layers": n_full_layers, + "total_mha_allfull": total_mha_allfull, + "total_gqa_allfull": total_gqa_allfull, + "total_mixed_mha": total_mixed_mha, + "total_mixed_gqa": total_mixed_gqa, + } + + +def main(): + p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA/GQA with SWA layer ratio") + p.add_argument("--context_length", default=1024, type=int) + p.add_argument("--sliding_window_size", required=True, type=int, + 
help="SWA window size W per SWA layer.") + p.add_argument("--emb_dim", required=True, type=int) + p.add_argument("--n_heads", required=True, type=int) + p.add_argument("--n_layers", required=True, type=int) + p.add_argument("--n_kv_groups", required=True, type=int, + help="GQA groups; 1 means MHA-equivalent KV heads.") + p.add_argument("--batch_size", default=1, type=int) + p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="fp16") + p.add_argument("--swa_ratio", default="1:0", + help="SWA:Full layer ratio. Example '5:1' -> 5 SWA for each 1 full. " + "'1:5' -> 1 SWA for 5 full. Default '1:0' = all SWA.") + args = p.parse_args() + + cfg = { + "context_length": args.context_length, + "sliding_window_size": args.sliding_window_size, + "emb_dim": args.emb_dim, + "n_heads": args.n_heads, + "n_layers": args.n_layers, + "n_kv_groups": args.n_kv_groups, + } + + res = estimate_totals( + context_length=cfg["context_length"], + sliding_window_size=cfg["sliding_window_size"], + emb_dim=cfg["emb_dim"], + n_heads=cfg["n_heads"], + n_layers=cfg["n_layers"], + n_kv_groups=cfg["n_kv_groups"], + batch_size=args.batch_size, + dtype=args.dtype, + swa_ratio=args.swa_ratio, + ) + + print("==== Config ====") + for k, v in cfg.items(): + print(f"{k:23}: {v}") + print(f"batch_size : {args.batch_size}") + print(f"dtype : {args.dtype} ({res['bytes_per_elem']} Bytes/elem)") + print(f"head_dim : {res['head_dim']}") + print(f"GQA n_kv_heads : {res['n_kv_heads_gqa']}") + print(f"Effective SWA window W : {res['eff_W']}") + print(f"Layer ratio (SWA:Full) : {args.swa_ratio} -> " + f"{res['n_swa_layers']} SWA, {res['n_full_layers']} Full") + print() + + print("==== KV-cache totals across all layers ====") + print(f"MHA KV total : {bytes_convert(res['total_mha_allfull'])}") + print(f"GQA KV total : {bytes_convert(res['total_gqa_allfull'])}") + print(f"MHA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_mha'])}") + print(f"GQA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_gqa'])}") + print() + + +if __name__ == "__main__": + main() diff --git a/ch04/06_swa/plot_memory_estimates_mla.py b/ch04/06_swa/plot_memory_estimates_mla.py new file mode 100644 index 000000000..e4c420880 --- /dev/null +++ b/ch04/06_swa/plot_memory_estimates_mla.py @@ -0,0 +1,90 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import matplotlib.pyplot as plt + +# Bytes per element +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_to_gb(n_bytes): + return n_bytes / (1000. 
** 3) + + +def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads, + n_layers, bytes_per_elem): + head_dim = emb_dim / n_heads + per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem + return per_layer * n_layers + + +def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem): + return batch * context_length * n_layers * latent_dim * bytes_per_elem + + +def plot_abs_kv_vs_context_multiple(): + n_heads = 24 + emb_dim = 2048 + n_layers = 48 + batch_size = 1 + dtype = "bf16" + bytes_per_elem = DTYPE_BYTES[dtype] + + context_lengths = [ + 256, 512, 1024, 2048, 4096, 8192, + 16384, 32768, 65536, 131072 + ] + + mha_gb = [] + for L in context_lengths: + total_mha = kv_bytes_total_mha( + batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem + ) + mha_gb.append(bytes_to_gb(total_mha)) + + latent_dims = [1024, 512, 256, 64] + plt.figure() + plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)") + + L_ref = context_lengths[-1] + total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem) + + for latent_dim in latent_dims: + mla_gb = [] + for L in context_lengths: + total_mla = kv_bytes_total_mla( + batch_size, L, n_layers, latent_dim, bytes_per_elem + ) + mla_gb.append(bytes_to_gb(total_mla)) + + total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem) + comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf") + + plt.plot(context_lengths, mla_gb, marker="o", + label=f"MLA (latent_dim={latent_dim}, {comp:,.1f}× compression)") + + plt.xscale("log") + plt.xlabel("context_length (log scale)") + plt.ylabel("Total KV cache (GB)") + plt.title( + "KV-cache vs Context Length — MHA vs MLA\n" + f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, " + f"batch={batch_size}, dtype={dtype})", + fontsize=8 + ) + plt.grid(True, which="both") + plt.legend() + plt.tight_layout() + plt.savefig("kv_bytes_vs_context_length.pdf") + + +if __name__ == "__main__": + plot_abs_kv_vs_context_multiple() diff --git a/ch04/06_swa/plot_memory_estimates_swa.py b/ch04/06_swa/plot_memory_estimates_swa.py new file mode 100644 index 000000000..b75f0cf0a --- /dev/null +++ b/ch04/06_swa/plot_memory_estimates_swa.py @@ -0,0 +1,231 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# Sliding Window Attention (SWA) memory usage vs context length plot. +# +# This script mirrors the style and structure of plot_memory_estimates_mla.py. 
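+#
+# Example usage (matching the README in this folder):
+#   uv run plot_memory_estimates_swa.py \
+#     --emb_dim 4096 --n_heads 48 --n_layers 36 \
+#     --batch_size 1 --dtype bf16 \
+#     --sliding_window_size 2048 --swa_ratio "5:1"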
+ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +# Bytes per element +DTYPE_BYTES = { + "fp32": 4, + "bf16": 2, + "fp16": 2, + "fp8": 1, + "int8": 1, +} + + +def bytes_to_gb(n_bytes): + return n_bytes / (1000.0 ** 3) + + +def parse_ratio(ratio_str): + # "--swa_ratio a:b" means a SWA layers for every b full layers within a block + try: + a_str, b_str = ratio_str.split(":") + a, b = int(a_str), int(b_str) + assert a >= 0 and b >= 0 and (a + b) > 0 + return a, b + except Exception: + raise ValueError("--swa_ratio must be in the form 'a:b' with nonnegative integers and a+b>0") + + +def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem): + # For MHA, n_kv_heads = n_heads, which cancels out: + # total = B * L * E * 2 (K,V) * bytes * n_layers + return batch * context_length * emb_dim * 2 * bytes_per_elem * n_layers + + +def kv_bytes_total_gqa( + batch, context_length, emb_dim, n_layers, bytes_per_elem, n_kv_groups +): + # For GQA, n_kv_heads = n_heads / n_kv_groups + # => scale the MHA total by 1 / n_kv_groups + base = kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem) + return base / n_kv_groups + + +def kv_bytes_total_mha_swa( + batch, context_length, emb_dim, n_layers, bytes_per_elem, window, swa_ratio +): + # Split layers into SWA vs Full + a, b = parse_ratio(swa_ratio) + total_blocks = a + b + n_swa_layers = int(round(n_layers * (a / total_blocks))) + n_full_layers = n_layers - n_swa_layers + + total_full = kv_bytes_total_mha( + batch, context_length, emb_dim, n_full_layers, bytes_per_elem + ) + total_swa = kv_bytes_total_mha( + batch, window, emb_dim, n_swa_layers, bytes_per_elem + ) + return total_full + total_swa + + +def kv_bytes_total_gqa_swa( + batch, + context_length, + emb_dim, + n_layers, + bytes_per_elem, + n_kv_groups, + window, + swa_ratio, +): + a, b = parse_ratio(swa_ratio) + total_blocks = a + b + n_swa_layers = int(round(n_layers * (a / total_blocks))) + n_full_layers = n_layers - n_swa_layers + + total_full = kv_bytes_total_gqa( + batch, + context_length, + emb_dim, + n_full_layers, + bytes_per_elem, + n_kv_groups, + ) + total_swa = kv_bytes_total_gqa( + batch, window, emb_dim, n_swa_layers, bytes_per_elem, n_kv_groups + ) + return total_full + total_swa + + +def main(): + p = argparse.ArgumentParser( + description="KV-cache vs Context Length — MHA vs GQA with SWA overlays" + ) + p.add_argument("--emb_dim", type=int, required=True) + p.add_argument("--n_heads", type=int, required=True) + p.add_argument("--n_layers", type=int, required=True) + p.add_argument("--batch_size", type=int, default=1) + p.add_argument("--dtype", choices=DTYPE_BYTES.keys(), default="bf16") + p.add_argument( + "--sliding_window_size", type=int, required=True, help="SWA window size W" + ) + p.add_argument("--swa_ratio", type=str, default="5:1", help="SWA:Full ratio, e.g., 5:1") + p.add_argument( + "--output", type=Path, default=Path("kv_bytes_vs_context_length.pdf") + ) + args = p.parse_args() + + batch_size = args.batch_size + emb_dim = args.emb_dim + n_heads = args.n_heads + n_layers = args.n_layers + bytes_per_elem = DTYPE_BYTES[args.dtype] + + kv_groups = 4 + valid_g4 = (n_heads % kv_groups == 0) + + context_lengths = [ + 256, 512, 1024, 2048, 4096, 8192, + 16384, 32768, 65536, 131072 + ] + + series = { + "MHA (KV total)": [], + f"SWA on MHA (ratio {args.swa_ratio}, W={args.sliding_window_size})": [], + } + if valid_g4: + series["GQA kv_groups=4 (full)"] = [] + series[ + f"SWA on GQA 
kv_groups=4 (ratio {args.swa_ratio}, W={args.sliding_window_size})" + ] = [] + + for L in context_lengths: + total_mha = kv_bytes_total_mha( + batch_size, L, emb_dim, n_layers, bytes_per_elem + ) + total_mha_swa = kv_bytes_total_mha_swa( + batch_size, + L, + emb_dim, + n_layers, + bytes_per_elem, + window=args.sliding_window_size, + swa_ratio=args.swa_ratio, + ) + series["MHA (KV total)"].append(bytes_to_gb(total_mha)) + series[ + f"SWA on MHA (ratio {args.swa_ratio}, W={args.sliding_window_size})" + ].append(bytes_to_gb(total_mha_swa)) + + if valid_g4: + total_gqa = kv_bytes_total_gqa( + batch_size, L, emb_dim, n_layers, bytes_per_elem, n_kv_groups=kv_groups + ) + total_gqa_swa = kv_bytes_total_gqa_swa( + batch_size, + L, + emb_dim, + n_layers, + bytes_per_elem, + n_kv_groups=kv_groups, + window=args.sliding_window_size, + swa_ratio=args.swa_ratio, + ) + series["GQA kv_groups=4 (full)"].append(bytes_to_gb(total_gqa)) + series[ + f"SWA on GQA kv_groups=4 (ratio {args.swa_ratio}, W={args.sliding_window_size})" + ].append(bytes_to_gb(total_gqa_swa)) + + plt.figure(figsize=(10, 5)) + x = np.array(context_lengths, dtype=float) + + colors = { + "MHA": "#1f77b4", + "GQA": "#ff7f0e", + } + + for label, yvals in series.items(): + y = np.array(yvals, dtype=float) + if np.all(np.isnan(y)): + continue + + linestyle = "--" if "SWA" in label else "-" + if "MHA" in label: + color = colors["MHA"] + elif "GQA" in label: + color = colors["GQA"] + else: + color = None + + plt.plot(x, y, marker="o", label=label, linestyle=linestyle, color=color) + + plt.xscale("log") + plt.xlabel("context_length (log scale)") + plt.ylabel("Total KV cache (GB)") + plt.title( + "KV-cache vs Context Length — MHA vs GQA (SWA overlays)\n" + f"(n_heads={n_heads}, emb_dim={emb_dim}, n_layers={n_layers}, " + f"batch={batch_size}, dtype={args.dtype}; " + f"SWA ratio={args.swa_ratio}, W={args.sliding_window_size})", + fontsize=8, + ) + plt.grid(True, which="both") + plt.legend() + plt.tight_layout() + plt.savefig(args.output) + plt.close() + + if not valid_g4: + print( + f"Skipped GQA kv_groups=4 because n_heads={args.n_heads} " + "is not divisible by 4." + ) + print(f"Saved plot to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/ch04/README.md b/ch04/README.md index 0a95ce14f..e058e9a66 100644 --- a/ch04/README.md +++ b/ch04/README.md @@ -13,6 +13,7 @@ - [ch05/07_gpt_to_llama](../ch05/07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI (it might be interesting to look at alternative architectures after completing chapter 4, but you can also save that for after reading chapter 5) - [04_gqa](04_gqa) contains an introduction to Grouped-Query Attention (GQA), which is used by most modern LLMs (Llama 4, gpt-oss, Qwen3, Gemma 3, and many more) as alternative to regular Multi-Head Attention (MHA) - [05_mla](05_mla) contains an introduction to Multi-Head Latent Attention (MLA), which is used by DeepSeek V3, as alternative to regular Multi-Head Attention (MHA) +- [06_swa](06_swa) contains an introduction to Sliding Window Attention (SWA), which is used by Gemma 3 and others