(Unofficial) PyTorch implementation of grouped-query attention (GQA) from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (https://arxiv.org/pdf/2305.13245.pdf)

fkodom/grouped-query-attention-pytorch

grouped-query-attention-pytorch

(Unofficial) PyTorch implementation of grouped-query attention (GQA) from GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

[Figure: comparison of attention mechanisms]

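Includes:

  • Scaled dot product attention with GQA support (see: scaled_dot_product_gqa)
  • GQA multi-head attention layer (see: MultiheadGQA)
  • Code to convert a pretrained T5 model to use GQA (see: T5)
  • Prototype (untrained) GQA encoder-decoder models: GQATransformer, GQATransformerLM
  • Runtime benchmarks reproducing Figure 6 of the GQA paper (see: scripts/README.md)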

To do:

  • Fine-tuning code for T5 GQA models
  • Reproduce fine-tuning results from the GQA paper (Figures 3 and 5)

Install

PyPI: (NOT YET AVAILABLE)

pip install grouped-query-attention-pytorch

From source:

pip install "grouped-query-attention-pytorch @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"

For contributors:

# Install all dev dependencies (tests, T5 support, etc.)
pip install "grouped-query-attention-pytorch[test,t5] @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
# Setup pre-commit hooks
pre-commit install

Benchmark

I attempt to reproduce the runtime benchmarks from the GQA paper (Figure 6). I don't have access to the same hardware (the authors use multiple high-end GPUs, and I use a single 2080 Ti), so the comparison isn't exact. Even with different hardware, though, runtime clearly scales similarly with the number of GQA groups.

For more details, see scripts/README.md

[Figure: runtime benchmark plots. Left: this repo. Right: original paper.]

Usage

scaled_dot_product_gqa

See: attention.py

Intended to be a drop-in replacement for F.scaled_dot_product_attention with support for GQA.

NOTE: The built-in F.scaled_dot_product_attention will be much faster when you're not using grouped queries, especially for torch>=2.0, which uses flash attention under the hood. However, the benchmark above shows that naive scaled_dot_product_gqa is faster than flash attention when the number of GQA groups is small. 🔥

import torch

from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)

out, attn_weights = scaled_dot_product_gqa(
    query,
    key,
    value,
    is_causal=True,  # default: False
    need_weights=True,  # default: False, which returns 'attn_weights=None'
)
print(out.shape)  # (batch_size, q_seq_len, kv_heads, embed_dim)
# torch.Size([1, 256, 2, 64])
print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])
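
For comparison, the textbook GQA formulation can be sketched on top of the built-in F.scaled_dot_product_attention by repeating each key/value head across its group of query heads. This is a minimal sketch under stated assumptions, not this repo's implementation: `reference_gqa` is a hypothetical helper, it ignores masking and dropout, and it runs on CPU in float32. It returns one output per query head, which is not the shape the example above reports for scaled_dot_product_gqa, so treat it as a conceptual baseline only.

```python
import torch
import torch.nn.functional as F


def reference_gqa(query, key, value):
    """Hypothetical baseline, not part of this repo.

    Inputs use this README's layout: (batch_size, seq_len, num_heads, head_dim).
    """
    q_heads, kv_heads = query.shape[2], key.shape[2]
    if q_heads % kv_heads != 0:
        raise ValueError("query heads must be a multiple of key/value heads")
    group_size = q_heads // kv_heads

    # F.scaled_dot_product_attention expects (batch_size, num_heads, seq_len, head_dim).
    q = query.transpose(1, 2)
    # Repeat each KV head so that consecutive query heads in a group share it.
    k = key.transpose(1, 2).repeat_interleave(group_size, dim=1)
    v = value.transpose(1, 2).repeat_interleave(group_size, dim=1)

    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)  # back to (batch_size, seq_len, num_heads, head_dim)


query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 128, 2, 64)
value = torch.randn(1, 128, 2, 64)
out = reference_gqa(query, key, value)
print(out.shape)
# torch.Size([1, 256, 8, 64])
```

Note the layout difference: the built-in function puts the head dimension before the sequence dimension, while the examples in this README put it after.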

MultiheadGQA

See: attention.py

Intended to be a drop-in replacement for nn.MultiheadAttention with support for GQA.

NOTE: The same performance advice from scaled_dot_product_gqa (above) applies here as well.

import torch

from grouped_query_attention_pytorch.attention import MultiheadGQA

mha = MultiheadGQA(
    embed_dim=512, query_heads=8, kv_heads=2, device="cuda", dtype=torch.float16
)

# shapes: (batch_size, seq_len, embed_dim)
query = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)

out, attn_weights = mha(
    query,
    key,
    value,
    is_causal=True,  # default: False
    need_weights=True,  # default: False, which returns 'attn_weights=None'
)
print(out.shape)  # (batch_size, q_seq_len, embed_dim)
# torch.Size([1, 256, 512])
print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])
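
To make the effect of kv_heads concrete, here is a rough sketch of the key/value projection sizes for the configuration above. It assumes the usual GQA construction, where head_dim = embed_dim // query_heads and keys/values are projected to kv_heads * head_dim features. The layer names are hypothetical stand-ins, not attributes of MultiheadGQA.

```python
from torch import nn

embed_dim, query_heads, kv_heads = 512, 8, 2
head_dim = embed_dim // query_heads  # 64
kv_embed_dim = head_dim * kv_heads  # 128

# Hypothetical stand-ins for the K/V projections in a GQA layer vs. an MHA layer.
k_proj_gqa = nn.Linear(embed_dim, kv_embed_dim, bias=False)
v_proj_gqa = nn.Linear(embed_dim, kv_embed_dim, bias=False)
k_proj_mha = nn.Linear(embed_dim, embed_dim, bias=False)
v_proj_mha = nn.Linear(embed_dim, embed_dim, bias=False)


def num_params(*modules):
    """Total number of parameters across the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())


gqa_kv_params = num_params(k_proj_gqa, v_proj_gqa)
mha_kv_params = num_params(k_proj_mha, v_proj_mha)
print(gqa_kv_params, mha_kv_params, gqa_kv_params / mha_kv_params)
# 131072 524288 0.25
```

Under these assumptions the ratio equals kv_heads / query_heads, and the same factor applies to the size of the key and value tensors themselves.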

T5

See: t5.py

Convert a pretrained T5 model from huggingface/transformers to use GQA. The resulting model can be used and trained with the Huggingface Transformers library, just like an ordinary T5 model.

from transformers import T5ForConditionalGeneration, T5Tokenizer

from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa

# Initialize a pre-trained T5 model
t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
# Convert attention layers to GQA
t5_gqa = convert_t5_to_gqa(t5, kv_heads=2, inplace=False)  # default: inplace=False

# Generate some text with the converted model
input_ids = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="pt"
).input_ids
outputs = t5_gqa.generate(input_ids, max_new_tokens=25)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(text)
# The correct answer is:  ['Das Haus ist wunderbar.']
# NOTE: The original T5 model produces this answer, and so does GQA when we use the
# maximum number of KV heads (kv_heads=8 for t5-small), which effectively makes
# GQA equivalent to the original T5 model with MHA.  The text quickly degrades as
# we reduce the number of heads.
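
The GQA paper converts a multi-head checkpoint by mean-pooling the key and value projection heads within each group. Below is a minimal sketch of that pooling step for a single projection weight, assuming the weight uses the nn.Linear layout with heads stored contiguously along the output dimension. `mean_pool_kv_weight` is a hypothetical helper for illustration; it is not necessarily how convert_t5_to_gqa is implemented.

```python
import torch


def mean_pool_kv_weight(weight, num_heads, kv_heads):
    """Hypothetical helper: average the projection heads within each group.

    weight: (num_heads * head_dim, d_model), the layout of nn.Linear.weight.
    """
    if num_heads % kv_heads != 0:
        raise ValueError("num_heads must be a multiple of kv_heads")
    out_features, in_features = weight.shape
    head_dim = out_features // num_heads
    group_size = num_heads // kv_heads
    # Split the output rows into (kv_heads, group_size, head_dim), then average each group.
    grouped = weight.reshape(kv_heads, group_size, head_dim, in_features)
    return grouped.mean(dim=1).reshape(kv_heads * head_dim, in_features)


# Sizes matching t5-small: 8 heads of dimension 64, d_model=512.
weight = torch.randn(8 * 64, 512)
pooled = mean_pool_kv_weight(weight, num_heads=8, kv_heads=2)
print(pooled.shape)
# torch.Size([128, 512])
```

With kv_heads equal to num_heads, each group holds a single head and the pooling leaves the weight unchanged, which is consistent with the note above that the maximum number of KV heads reproduces the original model.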

GQATransformer

I also provide a prototype implementation of an (untrained) encoder-decoder Transformer model, which uses GQA instead of MHA. This is mostly for reference/educational purposes, but in principle it could be used as a drop-in replacement for nn.Transformer.

See: transformer.py

import torch

from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM

device = torch.device("cuda")
dtype = torch.float16

net = GQATransformer(
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len, d_model)
x = torch.randn(1, 256, 512, device=device, dtype=dtype)
with torch.no_grad():
    y = net.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 256, 512])

num_tokens = 10000  # usually obtained from the tokenizer
lm = GQATransformerLM(
    num_tokens=num_tokens,  # required
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 256), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 256, num_tokens])
