Commit

Add triton implementation of FlashAttention (allenai#24)

epwalsh authored Mar 8, 2023
1 parent 9430863 commit 5222c35
Showing 10 changed files with 334 additions and 42 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/main.yml
@@ -136,6 +136,16 @@ jobs:
priority: preemptible
resources:
gpuCount: 1
constraints:
cluster:
- ai2/general-cirrascale
- ai2/general-cirrascale-a100-80g-ib
- ai2/allennlp-cirrascale
- ai2/aristo-cirrascale
- ai2/mosaic-cirrascale
- ai2/mosaic-cirrascale-a100
- ai2/prior-cirrascale
- ai2/s2-cirrascale
envVars:
- name: COMMIT_SHA
value: ${{ env.COMMIT_SHA }}
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -9,4 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added GPT-based model, tokenizer, data pipeline, and `composer` training script.
- GPT-based model.
- Tokenizer and data pre-processing pipeline.
- `composer` training script.
- Triton-based FlashAttention.
15 changes: 11 additions & 4 deletions Makefile
@@ -21,16 +21,23 @@ beaker-info :
@echo "Gantry image: $(GANTRY_IMAGE)"
@echo "Testing image: $(TEST_IMAGE)"

.PHONY : images
images : gantry-image test-image

.PHONY : base-image
base-image :
docker build -f docker/Dockerfile.base -t $(IMAGE_NAME_BASE)-base .

.PHONY : gantry-image
gantry-image :
docker build -f Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry .
gantry-image : base-image
docker build -f docker/Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry .
beaker image create $(IMAGE_NAME_BASE)-gantry --name $(IMAGE_NAME_BASE)-gantry-tmp --workspace $(BEAKER_WORKSPACE)
beaker image delete $(GANTRY_IMAGE) || true
beaker image rename $(BEAKER_USER)/$(IMAGE_NAME_BASE)-gantry-tmp $(IMAGE_NAME_BASE)-gantry

.PHONY : test-image
test-image :
docker build -f Dockerfile.test -t $(IMAGE_NAME_BASE)-test .
test-image : base-image
docker build -f docker/Dockerfile.test -t $(IMAGE_NAME_BASE)-test .
beaker image create $(IMAGE_NAME_BASE)-test --name $(IMAGE_NAME_BASE)-test-tmp --workspace $(BEAKER_WORKSPACE)
beaker image delete $(TEST_IMAGE) || true
beaker image rename $(BEAKER_USER)/$(IMAGE_NAME_BASE)-test-tmp $(IMAGE_NAME_BASE)-test
15 changes: 15 additions & 0 deletions docker/Dockerfile.base
@@ -0,0 +1,15 @@
# Defines a CUDA-enabled Docker image suitable for installing all dependencies
# to this project.

FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10

# We need cuda dev for the old version of triton.
# NOTE: once we're able to upgrade triton to >=2.0, we can remove this.
RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev

# Install flash attn (and triton dependency) from our pre-built wheel.
RUN /opt/conda/bin/pip install --no-cache-dir \
triton==2.0.0.dev20221202 \
https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl

ENV CUDA_HOME=/opt/conda
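A quick sanity check one could run inside the resulting `dolma-base` image to confirm the pinned Triton build and the pre-built flash-attn wheel import cleanly (an illustrative sketch, not part of the commit):

```python
# Run inside the dolma-base image; both imports should succeed if the wheel
# and the pinned triton==2.0.0.dev20221202 installed correctly.
import triton
from flash_attn import flash_attn_triton  # the module dolma/model.py relies on

print("triton:", triton.__version__)
print("flash-attn Triton kernel:", flash_attn_triton.flash_attn_qkvpacked_func)
```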
2 changes: 1 addition & 1 deletion Dockerfile.gantry → docker/Dockerfile.gantry
@@ -4,7 +4,7 @@
# To build and push the image to Beaker, run 'make gantry-image'.
# To test the image after pushing to Beaker, run 'make gantry-test'.

FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10
FROM dolma-base

WORKDIR /stage

2 changes: 1 addition & 1 deletion Dockerfile.test → docker/Dockerfile.test
@@ -4,7 +4,7 @@
#
# To build and push the image to Beaker, run 'make test-image'.

FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10
FROM dolma-base

COPY scripts/test_entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
9 changes: 7 additions & 2 deletions dolma/config.py
@@ -101,19 +101,24 @@ class ModelConfig(BaseConfig):

mlp_ratio: int = 4
"""
The ratio of the inner MLP dimensionality to `d_model`.
The ratio of the inner MLP dimensionality to ``d_model``.
"""

alibi: bool = False
"""
If `True`, use ALiBi embeddings.
If ``True``, use ALiBi embeddings.
"""

alibi_bias_max: float = 8.0
"""
Maximum absolute value of ALiBi bias.
"""

flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
"""

attention_dropout: float = 0.1
"""
The dropout probability within the attention modules.
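A minimal sketch of how the new flag plugs in (hypothetical values; it assumes `ModelConfig` accepts keyword overrides and that its remaining fields keep their defaults — FlashAttention additionally requires a head dimension of 64 or 128 and zero attention dropout, per the assertions added in `dolma/model.py` below):

```python
from dolma.config import ModelConfig
from dolma.model import GPTBlock

# d_model / n_heads = 64 satisfies the FlashAttention head-dim constraint,
# and attention_dropout must be 0 for the Triton kernel.
config = ModelConfig(
    d_model=1024,
    n_heads=16,
    attention_dropout=0.0,
    flash_attention=True,
)
block = GPTBlock(config)  # selects FlashAttention instead of TorchAttention
```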
118 changes: 109 additions & 9 deletions dolma/model.py
@@ -5,37 +5,60 @@
"""

import math
from abc import abstractmethod
from typing import NamedTuple, Optional, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .config import ModelConfig

__all__ = ["SelfAttention", "GPTMLP", "GPTBlock", "DolmaGPT"]
__all__ = ["TorchAttention", "GPTMLP", "GPTBlock", "DolmaGPT"]


class SelfAttention(nn.Module):
class DolmaAttentionBase(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
assert config.d_model % config.n_heads == 0
self.n_heads = config.n_heads
self.d_model = config.d_model

# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.d_model, 3 * config.d_model, device=config.init_device)
# for param init fn
self.c_attn._fused = (0, (self.d_model, 2 * self.d_model)) # type: ignore

# output projection
self.c_proj = nn.Linear(config.d_model, config.d_model, device=config.init_device)
# for param init fn
self.c_proj._is_residual = True # type: ignore

# regularization
self.attn_dropout = nn.Dropout(config.attention_dropout)
self.resid_dropout = nn.Dropout(config.residual_dropout)

# optional layer norm for keys and queries.
self.k_ln: Optional[nn.LayerNorm] = None
self.q_ln: Optional[nn.LayerNorm] = None
if config.attention_layer_norm:
self.k_ln = nn.LayerNorm(self.d_model, device=config.init_device)
self.q_ln = nn.LayerNorm(self.d_model, device=config.init_device)

@abstractmethod
def forward(
self,
x: torch.FloatTensor,
attention_bias: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
raise NotImplementedError


class TorchAttention(DolmaAttentionBase):
def __init__(self, config: ModelConfig):
super().__init__(config)

def forward(
self,
x: torch.FloatTensor,
@@ -55,8 +78,9 @@ def forward(

# Optionally apply layer norm to keys and queries.
if self.k_ln is not None and self.q_ln is not None:
k = self.k_ln(k)
q = self.q_ln(q)
dtype = k.dtype
k = self.k_ln(k).to(dtype=dtype)
q = self.q_ln(q).to(dtype=dtype)

# Move head forward to be next to the batch dim.
# shape (all): (B, nh, T, hs)
@@ -87,6 +111,55 @@ def forward(self, x):
return y
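The explicit `.to(dtype=...)` casts after the key/query layer norms matter under mixed precision: with CUDA autocast, `nn.LayerNorm` is computed in float32, so without the cast the normalized `q`/`k` would come back in a different dtype than the rest of the attention math. A standalone illustration (assumes a CUDA device; not part of the commit):

```python
import torch
import torch.nn as nn

lin = nn.Linear(64, 64).cuda()
ln = nn.LayerNorm(64).cuda()
x = torch.randn(2, 8, 64, device="cuda")

with torch.autocast("cuda", dtype=torch.float16):
    q = lin(x)
    print(q.dtype)                        # torch.float16 (autocast matmul)
    print(ln(q).dtype)                    # torch.float32 (LayerNorm runs in fp32 under autocast)
    print(ln(q).to(dtype=q.dtype).dtype)  # torch.float16, matching q again
```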


class FlashAttention(DolmaAttentionBase):
"""
Triton implementation of FlashAttention.
"""

def __init__(self, config: ModelConfig):
from flash_attn import flash_attn_triton # type: ignore

super().__init__(config)

assert self.d_model / self.n_heads in {64, 128}, "FlashAttention requires head dim of 64 or 128 for now"
assert config.attention_dropout == 0, "FlashAttention does not support attention dropout for now"
self.flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func

def forward(
self, x: torch.FloatTensor, attention_bias: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
"""
:param x: A tensor of shape `(batch_size, seq_len, d_model)`.
:param attention_bias: A tensor of shape `(batch_size, n_heads, seq_len, seq_len)`
or an equivalently broadcastable shape. This is used to introduce causal or other biases
and it is simply added to the attention scores before the softmax.
"""
# Calculate query, key, values for all heads in batch.
# shape: (batch_size, seq_length, d_model * 3)
qkv = self.c_attn(x)

# Optionally apply layer norm to keys and queries.
if self.q_ln is not None and self.k_ln is not None:
# Apply layer norm to the queries and keys.
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype=dtype)
k = self.k_ln(k).to(dtype=dtype)
qkv = torch.cat([q, k, v], dim=-1)

# Apply inner attention function.
qkv = rearrange(qkv, "b s (t h d) -> b s t h d", t=3, h=self.n_heads)
y = self.flash_attn_qkvpacked_func(qkv, attention_bias)

# Re-assemble all head outputs side by side.
y = rearrange(y, "b s h d -> b s (h d)")

# Apply output projection.
y = self.resid_dropout(self.c_proj(y))

return y
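The shape bookkeeping around the packed-QKV call, shown standalone with toy sizes (batch 2, sequence 16, 4 heads of dim 64, so `d_model = 256`); this only illustrates the `rearrange` patterns used above, not the kernel itself:

```python
import torch
from einops import rearrange

B, S, H, D = 2, 16, 4, 64
qkv = torch.randn(B, S, 3 * H * D)                 # what c_attn produces: (B, S, 3 * d_model)
qkv = rearrange(qkv, "b s (t h d) -> b s t h d", t=3, h=H)
print(qkv.shape)                                   # torch.Size([2, 16, 3, 4, 64])

q, k, v = qkv.unbind(dim=2)                        # conceptually what the packed kernel consumes
y = torch.randn(B, S, H, D)                        # kernel output layout: (B, S, n_heads, head_dim)
print(rearrange(y, "b s h d -> b s (h d)").shape)  # torch.Size([2, 16, 256])
```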


class GPTMLP(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
@@ -103,8 +176,11 @@ def forward(self, x):
class GPTBlock(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.ln_1 = nn.LayerNorm(config.d_model, device=config.init_device)
self.attn = SelfAttention(config)
self.attn: DolmaAttentionBase = (
FlashAttention(config) if config.flash_attention else TorchAttention(config)
)
self.ln_2 = nn.LayerNorm(config.d_model, device=config.init_device)
self.mlp = GPTMLP(config)

@@ -357,16 +433,40 @@ def param_init_fn(self, module):

init_fn = partial(torch.nn.init.normal_, mean=0.0, std=self.config.init_std)

def fused_init_fn(module):
# Parameter initialization is often based on the parameter's shape.
# If a layer is fused, initialization should be based on the shapes
# of the original tensors instead of the shape of the fused tensor.
# Layers which are fused should have the _fused attribute defined.
# The first element of _fused is the dimension along which the tensor is fused.
# This is followed by an iterable of split indices.
_fused = getattr(module, "_fused", None)
if _fused is None:
raise RuntimeError("Internal logic error")

dim, splits = _fused
splits = (0, *splits, module.weight.size(dim))
for s, e in zip(splits[:-1], splits[1:]):
slice_indices = [slice(None)] * module.weight.ndim
slice_indices[dim] = slice(s, e)
init_fn(module.weight[slice_indices])

# Linear
if isinstance(module, nn.Linear):
init_fn(module.weight)
if hasattr(module, "_fused"):
fused_init_fn(module)
else:
init_fn(module.weight)

if module.bias is not None:
torch.nn.init.zeros_(module.bias)

if getattr(module, "_is_residual", False):
module.weight.data.normal_(
mean=0.0, std=(self.config.init_std / math.sqrt(2 * self.config.n_layers))
)
with torch.no_grad():
module.weight.div_(math.sqrt(2 * self.config.n_layers))

if module.bias is not None:
torch.nn.init.zeros_(module.bias)

# Embedding
if isinstance(module, nn.Embedding):
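A small standalone sketch of the fused-initialization convention handled above: `_fused` records the dimension along which the logical parameters were concatenated plus the interior split points, and each slice is then initialized on its own (toy sizes and a made-up std, mirroring the loop in `fused_init_fn`):

```python
import torch
import torch.nn as nn

d_model = 8
c_attn = nn.Linear(d_model, 3 * d_model)
# Same convention as DolmaAttentionBase.__init__: fused along dim 0 of the weight,
# with interior split points at d_model and 2 * d_model (the q | k | v boundaries).
c_attn._fused = (0, (d_model, 2 * d_model))

dim, splits = c_attn._fused
splits = (0, *splits, c_attn.weight.size(dim))
for s, e in zip(splits[:-1], splits[1:]):
    idx = [slice(None)] * c_attn.weight.ndim
    idx[dim] = slice(s, e)
    # Each of the q / k / v blocks gets its own init call, so statistics are based
    # on the per-block shape rather than the fused (3 * d_model, d_model) shape.
    torch.nn.init.normal_(c_attn.weight[tuple(idx)], mean=0.0, std=0.02)
```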
4 changes: 4 additions & 0 deletions requirements.txt
@@ -2,6 +2,7 @@
# Docker images. See each Dockerfile for details on how to do that.
numpy
torch
einops
# bug with 0.13.0, see https://github.com/mosaicml/composer/issues/2030
mosaicml!=0.13.0
torchmetrics
@@ -12,3 +13,6 @@ cached-path
beaker-gantry
omegaconf
wandb
# Can't install these on a CPU-only environment:
# triton
# flash-attn