From 5222c35b2e020b5b1245742307156a839bda2a38 Mon Sep 17 00:00:00 2001 From: Pete Date: Wed, 8 Mar 2023 15:56:56 -0800 Subject: [PATCH] Add triton implementation of FlashAttention (#24) --- .github/workflows/main.yml | 10 + CHANGELOG.md | 5 +- Makefile | 15 +- docker/Dockerfile.base | 15 ++ Dockerfile.gantry => docker/Dockerfile.gantry | 2 +- Dockerfile.test => docker/Dockerfile.test | 2 +- dolma/config.py | 9 +- dolma/model.py | 118 ++++++++++- requirements.txt | 4 + tests/model_test.py | 196 +++++++++++++++--- 10 files changed, 334 insertions(+), 42 deletions(-) create mode 100644 docker/Dockerfile.base rename Dockerfile.gantry => docker/Dockerfile.gantry (86%) rename Dockerfile.test => docker/Dockerfile.test (87%) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 47bd6d223..f963c7d2f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -136,6 +136,16 @@ jobs: priority: preemptible resources: gpuCount: 1 + constraints: + cluster: + - ai2/general-cirrascale + - ai2/general-cirrascale-a100-80g-ib + - ai2/allennlp-cirrascale + - ai2/aristo-cirrascale + - ai2/mosaic-cirrascale + - ai2/mosaic-cirrascale-a100 + - ai2/prior-cirrascale + - ai2/s2-cirrascale envVars: - name: COMMIT_SHA value: ${{ env.COMMIT_SHA }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 82a3a6ea9..7051173b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,4 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added GPT-based model, tokenizer, data pipeline, and `composer` training script. +- GPT-based model. +- Tokenizer and data pre-processing pipeline. +- `composer` training script. +- Triton-based FlashAttention. diff --git a/Makefile b/Makefile index e496a9224..07393fedb 100644 --- a/Makefile +++ b/Makefile @@ -21,16 +21,23 @@ beaker-info : @echo "Gantry image: $(GANTRY_IMAGE)" @echo "Testing image: $(TEST_IMAGE)" +.PHONY : images +images : gantry-image test-image + +PHONY : base-image +base-image : + docker build -f docker/Dockerfile.base -t $(IMAGE_NAME_BASE)-base . + .PHONY : gantry-image -gantry-image : - docker build -f Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry . +gantry-image : base-image + docker build -f docker/Dockerfile.gantry -t $(IMAGE_NAME_BASE)-gantry . beaker image create $(IMAGE_NAME_BASE)-gantry --name $(IMAGE_NAME_BASE)-gantry-tmp --workspace $(BEAKER_WORKSPACE) beaker image delete $(GANTRY_IMAGE) || true beaker image rename $(BEAKER_USER)/$(IMAGE_NAME_BASE)-gantry-tmp $(IMAGE_NAME_BASE)-gantry .PHONY : test-image -test-image : - docker build -f Dockerfile.test -t $(IMAGE_NAME_BASE)-test . +test-image : base-image + docker build -f docker/Dockerfile.test -t $(IMAGE_NAME_BASE)-test . beaker image create $(IMAGE_NAME_BASE)-test --name $(IMAGE_NAME_BASE)-test-tmp --workspace $(BEAKER_WORKSPACE) beaker image delete $(TEST_IMAGE) || true beaker image rename $(BEAKER_USER)/$(IMAGE_NAME_BASE)-test-tmp $(IMAGE_NAME_BASE)-test diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base new file mode 100644 index 000000000..83411da3f --- /dev/null +++ b/docker/Dockerfile.base @@ -0,0 +1,15 @@ +# Defines a CUDA-enabled Docker image suitable for installing all dependencies +# to this project. + +FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10 + +# We need cuda dev for the old version of triton. +# NOTE: once we're able to upgrade triton to >=2.0, we can remove this. +RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev + +# Install flash attn (and triton dependency) from our pre-built wheel. +RUN /opt/conda/bin/pip install --no-cache-dir \ + triton==2.0.0.dev20221202 \ + https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl + +ENV CUDA_HOME=/opt/conda diff --git a/Dockerfile.gantry b/docker/Dockerfile.gantry similarity index 86% rename from Dockerfile.gantry rename to docker/Dockerfile.gantry index c4ec30402..1387ebe66 100644 --- a/Dockerfile.gantry +++ b/docker/Dockerfile.gantry @@ -4,7 +4,7 @@ # To build and push the image to Beaker, run 'make gantry-image'. # To test the image after pushing to Beaker, run 'make gantry-test'. -FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10 +FROM dolma-base WORKDIR /stage diff --git a/Dockerfile.test b/docker/Dockerfile.test similarity index 87% rename from Dockerfile.test rename to docker/Dockerfile.test index 7cf8adc36..eb301a845 100644 --- a/Dockerfile.test +++ b/docker/Dockerfile.test @@ -4,7 +4,7 @@ # # To build and push the image to Beaker, run 'make test-image'. -FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10 +FROM dolma-base COPY scripts/test_entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh diff --git a/dolma/config.py b/dolma/config.py index 13d1206d8..a6bcab722 100644 --- a/dolma/config.py +++ b/dolma/config.py @@ -101,12 +101,12 @@ class ModelConfig(BaseConfig): mlp_ratio: int = 4 """ - The ratio of the inner MLP dimensionality to `d_model`. + The ratio of the inner MLP dimensionality to ``d_model``. """ alibi: bool = False """ - If `True`, use ALiBi embeddings. + If ``True``, use ALiBi embeddings. """ alibi_bias_max: float = 8.0 @@ -114,6 +114,11 @@ class ModelConfig(BaseConfig): Maximum absolute value of ALiBi bias. """ + flash_attention: bool = False + """ + If ``True``, use ``FlashAttention``. + """ + attention_dropout: float = 0.1 """ The dropout probability within the attention modules. diff --git a/dolma/model.py b/dolma/model.py index 84003bf2a..823208e2b 100644 --- a/dolma/model.py +++ b/dolma/model.py @@ -5,30 +5,40 @@ """ import math +from abc import abstractmethod from typing import NamedTuple, Optional, cast import torch import torch.nn as nn import torch.nn.functional as F +from einops import rearrange from .config import ModelConfig -__all__ = ["SelfAttention", "GPTMLP", "GPTBlock", "DolmaGPT"] +__all__ = ["TorchAttention", "GPTMLP", "GPTBlock", "DolmaGPT"] -class SelfAttention(nn.Module): +class DolmaAttentionBase(nn.Module): def __init__(self, config: ModelConfig): super().__init__() assert config.d_model % config.n_heads == 0 self.n_heads = config.n_heads self.d_model = config.d_model + # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear(config.d_model, 3 * config.d_model, device=config.init_device) + # for param init fn + self.c_attn._fused = (0, (self.d_model, 2 * self.d_model)) # type: ignore + # output projection self.c_proj = nn.Linear(config.d_model, config.d_model, device=config.init_device) + # for param init fn + self.c_proj._is_residual = True # type: ignore + # regularization self.attn_dropout = nn.Dropout(config.attention_dropout) self.resid_dropout = nn.Dropout(config.residual_dropout) + # optional layer norm for keys and queries. self.k_ln: Optional[nn.LayerNorm] = None self.q_ln: Optional[nn.LayerNorm] = None @@ -36,6 +46,19 @@ def __init__(self, config: ModelConfig): self.k_ln = nn.LayerNorm(self.d_model, device=config.init_device) self.q_ln = nn.LayerNorm(self.d_model, device=config.init_device) + @abstractmethod + def forward( + self, + x: torch.FloatTensor, + attention_bias: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + raise NotImplementedError + + +class TorchAttention(DolmaAttentionBase): + def __init__(self, config: ModelConfig): + super().__init__(config) + def forward( self, x: torch.FloatTensor, @@ -55,8 +78,9 @@ def forward( # Optionally apply layer norm to keys and queries. if self.k_ln is not None and self.q_ln is not None: - k = self.k_ln(k) - q = self.q_ln(q) + dtype = k.dtype + k = self.k_ln(k).to(dtype=dtype) + q = self.q_ln(q).to(dtype=dtype) # Move head forward to be next to the batch dim. # shape (all): (B, nh, T, hs) @@ -87,6 +111,55 @@ def forward( return y +class FlashAttention(DolmaAttentionBase): + """ + Triton implementation of FlashAttention. + """ + + def __init__(self, config: ModelConfig): + from flash_attn import flash_attn_triton # type: ignore + + super().__init__(config) + + assert self.d_model / self.n_heads in {64, 128}, "FlashAttention requires head dim of 64 or 128 for now" + assert config.attention_dropout == 0, "FlashAttention does not support attention dropout for now" + self.flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func + + def forward( + self, x: torch.FloatTensor, attention_bias: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + :param x: A tensor of shape `(batch_size, seq_len, d_model)`. + :param attention_bias: A tensor of shape `(batch_size, n_heads, seq_len, seq_len)` + or an equivalently broadcastable shape. This is used to introduce causal or other biases + and it is simply added to the attention scores before the softmax. + """ + # Calculate query, key, values for all heads in batch. + # shape: (batch_size, seq_length, d_model * 3) + qkv = self.c_attn(x) + + # Optionally apply layer norm to keys and queries. + if self.q_ln is not None and self.k_ln is not None: + # Applying layernorm to qk + dtype = qkv.dtype + q, k, v = qkv.split(self.d_model, dim=-1) + q = self.q_ln(q).to(dtype=dtype) + k = self.k_ln(k).to(dtype=dtype) + qkv = torch.cat([q, k, v], dim=-1) + + # Apply inner attention function. + qkv = rearrange(qkv, "b s (t h d) -> b s t h d", t=3, h=self.n_heads) + y = self.flash_attn_qkvpacked_func(qkv, attention_bias) + + # Re-assemble all head outputs side by side. + y = rearrange(y, "b s h d -> b s (h d)") + + # Apply output projection. + y = self.resid_dropout(self.c_proj(y)) + + return y + + class GPTMLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() @@ -103,8 +176,11 @@ def forward(self, x): class GPTBlock(nn.Module): def __init__(self, config: ModelConfig): super().__init__() + self.config = config self.ln_1 = nn.LayerNorm(config.d_model, device=config.init_device) - self.attn = SelfAttention(config) + self.attn: DolmaAttentionBase = ( + FlashAttention(config) if config.flash_attention else TorchAttention(config) + ) self.ln_2 = nn.LayerNorm(config.d_model, device=config.init_device) self.mlp = GPTMLP(config) @@ -357,16 +433,40 @@ def param_init_fn(self, module): init_fn = partial(torch.nn.init.normal_, mean=0.0, std=self.config.init_std) + def fused_init_fn(module): + # Parameter initialization is often based on the parameters shape. + # If a layer is fused, initialization should be based on the shapes + # of the original tensor instead of the shape of the fused tensor. + # Layers which are fused should have the _fused attribute defined. + # The first element of _fused is the dimension along which the tensor is fused. + # This is followed by an iterable of split indices. + _fused = getattr(module, "_fused", None) + if _fused is None: + raise RuntimeError("Internal logic error") + + dim, splits = _fused + splits = (0, *splits, module.weight.size(dim)) + for s, e in zip(splits[:-1], splits[1:]): + slice_indices = [slice(None)] * module.weight.ndim + slice_indices[dim] = slice(s, e) + init_fn(module.weight[slice_indices]) + # Linear if isinstance(module, nn.Linear): - init_fn(module.weight) + if hasattr(module, "_fused"): + fused_init_fn(module) + else: + init_fn(module.weight) + if module.bias is not None: torch.nn.init.zeros_(module.bias) if getattr(module, "_is_residual", False): - module.weight.data.normal_( - mean=0.0, std=(self.config.init_std / math.sqrt(2 * self.config.n_layers)) - ) + with torch.no_grad(): + module.weight.div_(math.sqrt(2 * self.config.n_layers)) + + if module.bias is not None: + torch.nn.init.zeros_(module.bias) # Embedding if isinstance(module, nn.Embedding): diff --git a/requirements.txt b/requirements.txt index 23c0e2d78..61e8555ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ # Docker images. See each Dockerfile for details on how to do that. numpy torch +einops # bug with 0.13.0, see https://github.com/mosaicml/composer/issues/2030 mosaicml!=0.13.0 torchmetrics @@ -12,3 +13,6 @@ cached-path beaker-gantry omegaconf wandb +# Can't install these on a CPU-only environment: +# triton +# flash-attn diff --git a/tests/model_test.py b/tests/model_test.py index b2bd2bb8c..2f8a49442 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -7,36 +7,95 @@ @pytest.mark.parametrize( - "alibi, cuda", + "alibi, flash_attn, cuda, dtype", [ - pytest.param(True, False, id="alibi-emb-cpu"), - pytest.param(False, False, id="posit-emb-cpu"), + pytest.param(True, False, False, torch.bfloat16, id="alibi-emb-cpu-bf16"), + pytest.param(False, False, False, torch.bfloat16, id="posit-emb-cpu-bf16"), + pytest.param(True, False, False, torch.float32, id="alibi-emb-cpu-f32"), + pytest.param(False, False, False, torch.float32, id="posit-emb-cpu-f32"), pytest.param( + True, + False, + True, + torch.bfloat16, + id="alibi-emb-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + False, + False, + True, + torch.bfloat16, + id="posit-emb-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + True, + True, + True, + torch.bfloat16, + id="alibi-emb-flash-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + False, + True, + True, + torch.bfloat16, + id="posit-emb-flash-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + True, True, True, - id="alibi-emb-cuda", + torch.float16, + id="alibi-emb-flash-cuda-f16", marks=( pytest.mark.gpu, - pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices"), + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), pytest.param( False, True, - id="posit-emb-cuda", + True, + torch.float16, + id="posit-emb-flash-cuda-f16", marks=( pytest.mark.gpu, - pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices"), + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), ], ) -def test_forward(train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool, cuda: bool): +def test_forward( + train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool, flash_attn: bool, cuda: bool, dtype +): torch.manual_seed(0) train_config.model.alibi = alibi + train_config.model.flash_attention = flash_attn + if flash_attn: + train_config.model.attention_dropout = 0.0 if cuda: train_config.model.init_device = "cuda" + else: + train_config.model.init_device = "cpu" + + use_amp = dtype in {torch.float16, torch.bfloat16} model = DolmaGPT(train_config.model).eval() @@ -52,34 +111,123 @@ def test_forward(train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool, c k: v.to(device=train_config.device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items() } - # Check that logits from individual inputs are equal to logits from batch. + # Run forward pass. with torch.inference_mode(): - output1 = model(torch.tensor(input1, device=train_config.device).unsqueeze(0)) - output2 = model(torch.tensor(input2, device=train_config.device).unsqueeze(0)) - batch_output = model(**batch_inputs) + with torch.autocast( + device_type="cuda" if cuda else "cpu", enabled=use_amp, dtype=None if not use_amp else dtype + ): + output1 = model(torch.tensor(input1, device=train_config.device).unsqueeze(0)) + output2 = model(torch.tensor(input2, device=train_config.device).unsqueeze(0)) + batch_output = model(**batch_inputs) - torch.testing.assert_close(output1.logits[0][: len(input1)], batch_output.logits[0][: len(input1)]) - torch.testing.assert_close(output2.logits[0][: len(input2)], batch_output.logits[1][: len(input2)]) + # Check that logits from individual inputs are equal to logits from batch. + # With using half-precision types these might have some big differences in a small + # percentage of the elements. + atol = 1e-2 if use_amp else None + rtol = 1e3 if use_amp else None + torch.testing.assert_close( + output1.logits[0][: len(input1)], batch_output.logits[0][: len(input1)], rtol=rtol, atol=atol + ) + torch.testing.assert_close( + output2.logits[0][: len(input2)], batch_output.logits[1][: len(input2)], rtol=rtol, atol=atol + ) -@pytest.mark.parametrize("alibi", [pytest.param(True, id="alibi-emb"), pytest.param(False, id="posit-emb")]) -def test_backward(train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool): +@pytest.mark.parametrize( + "alibi, flash_attn, cuda, dtype", + [ + pytest.param(True, False, False, torch.bfloat16, id="alibi-emb-cpu-bf16"), + pytest.param(False, False, False, torch.bfloat16, id="posit-emb-cpu-bf16"), + pytest.param( + True, + False, + True, + torch.bfloat16, + id="alibi-emb-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + False, + False, + True, + torch.bfloat16, + id="posit-emb-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + pytest.param( + True, + True, + True, + torch.bfloat16, + id="alibi-emb-flash-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + pytest.mark.skipif( + torch.cuda.device_count() < 1 or "A100" not in torch.cuda.get_device_name(), + reason="Requires A100 GPU type", + ), + ), + ), + pytest.param( + False, + True, + True, + torch.bfloat16, + id="posit-emb-flash-cuda-bf16", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + pytest.mark.skipif( + torch.cuda.device_count() < 1 or "A100" not in torch.cuda.get_device_name(), + reason="Requires A100 GPU type", + ), + ), + ), + ], +) +def test_backward( + train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool, flash_attn: bool, cuda: bool, dtype +): torch.manual_seed(0) + use_amp = dtype in {torch.float16, torch.bfloat16} + scaler = None if not (cuda and use_amp) else torch.cuda.amp.GradScaler() + train_config.model.alibi = alibi + train_config.model.flash_attention = flash_attn + if flash_attn: + train_config.model.attention_dropout = 0.0 + if cuda: + train_config.model.init_device = "cuda" + else: + train_config.model.init_device = "cpu" + model = DolmaGPT(train_config.model).train() - # Forward pass to get logits. - input_ids = torch.tensor(tokenizer.encode("My name is DOLMA!"), device=train_config.device).unsqueeze(0) - logits = model(input_ids).logits + with torch.autocast( + device_type="cuda" if cuda else "cpu", enabled=use_amp, dtype=None if not use_amp else dtype + ): + # Forward pass to get logits. + input_ids = torch.tensor(tokenizer.encode("My name is DOLMA!"), device=train_config.device).unsqueeze(0) + logits = model(input_ids).logits - # Compute loss. - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = input_ids[..., 1:].contiguous() - loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + # Compute loss. + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., 1:].contiguous() + loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # Backward pass. - loss.backward() + if scaler is not None: + scaler.scale(loss).backward() # type: ignore + else: + loss.backward() # Check gradients. for name, parameter in model.named_parameters():