Feature/44 make flash attention configurable #47

Open
wants to merge 22 commits into base: develop
Commits (changes from all 22 commits):
539e8a2
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
3317138
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
3186a8e
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
a86c9a8
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
105443f
feat: add softcap
theissenhelen Sep 20, 2024
e82a59e
test: add softcap
theissenhelen Sep 20, 2024
e648eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
6271cd8
feat: flash attention lazy import
theissenhelen Sep 23, 2024
d4940e7
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
9ff6cb9
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
bbd89dc
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
91533c6
feat: make alibi_slope configurable in block, chunk, processor
theissenhelen Oct 1, 2024
0eb5c50
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
c04e641
feat: get alibi_slopes
theissenhelen Oct 2, 2024
6523b47
docs: update docstrings
theissenhelen Oct 3, 2024
22623cc
fix: bias shape
theissenhelen Oct 3, 2024
ed07e34
fix: softcap optional
theissenhelen Oct 3, 2024
c841324
fix: import annotations from future
theissenhelen Oct 3, 2024
6c12dda
fix: annotation error
theissenhelen Oct 3, 2024
b7b8f2e
docs: update changelog
theissenhelen Oct 3, 2024
df353d9
fix: type annotation
theissenhelen Oct 7, 2024
fc335c7
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -27,15 +27,24 @@ Keep it human-readable, your future self will thank you!

### Added

- CI workflow to update the changelog on release
- Add configurability of flash attention (#47)
- Configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy


### Changed

- Update CI to inherit from common infrastructure reusable workflows
- Run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs.
- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
- Bugfixes for CI

### Removed

111 changes: 94 additions & 17 deletions src/anemoi/models/layers/attention.py
@@ -7,31 +7,27 @@
# nor does it submit to any jurisdiction.
#

from __future__ import annotations

import logging
import math
from typing import Optional

import einops
import torch
from packaging import version
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

try:
from flash_attn import flash_attn_func as attn_func
except ImportError:
from torch.nn.functional import scaled_dot_product_attention as attn_func

_FLASH_ATTENTION_AVAILABLE = False
else:
_FLASH_ATTENTION_AVAILABLE = True

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

LOGGER = logging.getLogger(__name__)


class MultiHeadSelfAttention(nn.Module):
"""Multi Head Self Attention Pytorch Layer."""
"""Multi Head Self Attention Pytorch Layer using flash attention, see https://github.com/Dao-AILab/flash-attention"""

def __init__(
self,
@@ -41,31 +37,77 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
):
"""Initialize MultiHeadSelfAttention.

Parameters
----------
num_heads : int
number of heads
embed_dim : int
embedding dimension
bias : bool, optional
bias, by default False
is_causal : bool, optional
apply causal attention mask, by default False
window_size : Optional[int], optional
window_size, by default None
dropout_p : float, optional
dropout probability, by default 0.0
softcap : float, optional
Anything > 0 activates softcapping attention, by default None
use_alibi_slopes : bool, optional
Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
to the attention score of query i and key j, where alibi_slope
is calculated using get_alibi_slopes, by default None
"""
super().__init__()

assert (
embed_dim % num_heads == 0
), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

self.use_flash_attention = use_flash_attention
self.set_attention_function()

self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = embed_dim // num_heads # q k v
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.softcap = softcap
self.use_alibi_slopes = use_alibi_slopes

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func
if self.use_alibi_slopes is not None:
self.alibi_slopes = get_alibi_slopes(num_heads)
assert self.alibi_slopes.shape[0] == num_heads

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def set_attention_function(self):

if self.use_flash_attention:
import flash_attn

if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
raise SystemExit("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
else:
self.attention = flash_attn.flash_attn_func
else:
from torch.nn.functional import scaled_dot_product_attention

self.attention = scaled_dot_product_attention

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:

query, key, value = self.lin_qkv(x).chunk(3, -1)

if model_comm_group:
@@ -88,11 +130,23 @@ def forward(
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
if self.use_flash_attention:
query, key, value = (
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)

alibi_slopes = self.alibi_slopes.repeat(batch_size, 1).to(query.device) if self.use_alibi_slopes else None

out = self.attention(
query,
key,
value,
causal=False,
window_size=self.window_size,
dropout_p=dropout_p,
softcap=self.softcap,
alibi_slopes=alibi_slopes,
)
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
else:
out = self.attention(
@@ -101,11 +155,34 @@ def forward(
value,
is_causal=False,
dropout_p=dropout_p,
) # expects (batch heads grid variable) format
)

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

out = self.projection(out)

return out


def get_alibi_slopes(num_heads: int) -> Tensor:
"""Calculates linearly decreasing slopes for alibi attention.

Parameters
----------
num_heads : int
number of attention heads

Returns
-------
Tensor
ALiBi slopes
"""
n = 2 ** math.floor(math.log2(num_heads))
slope_0 = 2.0 ** (-8.0 / n)
alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
if n < num_heads:
slope_hat_0 = 2.0 ** (-4.0 / n)
alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
return alibi_slopes
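
A minimal usage sketch, not part of the diff above: it exercises the new constructor options on the default scaled_dot_product_attention path (no flash-attn install needed) and shows the slopes produced by get_alibi_slopes. The input layout (batch * grid, embed_dim) and the shapes/batch_size arguments are assumptions based on the rearranges in forward and on the existing tests.

import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention
from anemoi.models.layers.attention import get_alibi_slopes

attention = MultiHeadSelfAttention(
    num_heads=8,
    embed_dim=128,
    dropout_p=0.1,
    use_flash_attention=False,  # True requires flash-attn >= 2.6.0 (see set_attention_function)
    softcap=None,
    use_alibi_slopes=False,
)

x = torch.randn(4 * 100, 128)  # assumed layout: (batch * grid, embed_dim)
out = attention(x, shapes=[1, 1, 1], batch_size=4, model_comm_group=None)
assert out.shape == x.shape  # forward returns "(batch grid) (heads vars)"

# ALiBi slopes form a single geometric series for power-of-two head counts
# and two interleaved series otherwise:
print(get_alibi_slopes(8))  # 2^-1, 2^-2, ..., 2^-8
print(get_alibi_slopes(6))  # 2^-2, 2^-4, 2^-6, 2^-8, plus 2^-1, 2^-3 for the two extra heads
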
6 changes: 6 additions & 0 deletions src/anemoi/models/layers/block.py
@@ -63,6 +63,9 @@
activation: str,
window_size: int,
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
):
super().__init__()

@@ -81,6 +84,9 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

self.mlp = nn.Sequential(
10 changes: 10 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -74,6 +74,9 @@
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
) -> None:
"""Initialize TransformerProcessor.

@@ -91,6 +94,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default None
use_alibi_slopes : bool, optional
Use ALiBi option, only used for flash attention, by default None
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -102,6 +109,9 @@ def __init__(
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

def forward(
10 changes: 10 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -96,6 +96,9 @@
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
use_flash_attention: bool = False,
softcap: float = 0.0,
use_alibi_slopes: bool = None,
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -116,6 +119,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default None
use_alibi_slopes : bool, optional
Use ALiBi option, only used for flash attention, by default None
"""
super().__init__(
num_channels=num_channels,
@@ -137,6 +144,9 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

self.offload_layers(cpu_offload)
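
The three new keyword arguments are threaded unchanged from TransformerProcessor through TransformerProcessorChunk and TransformerProcessorBlock down to MultiHeadSelfAttention. A hedged instantiation sketch follows; the keyword names come from the diffs and the existing tests, while the concrete values are illustrative only.

from anemoi.models.layers.processor import TransformerProcessor

processor = TransformerProcessor(
    num_layers=2,
    window_size=512,
    num_channels=128,
    num_chunks=2,
    activation="GELU",
    cpu_offload=False,
    num_heads=16,
    mlp_hidden_ratio=4,
    dropout_p=0.1,
    use_flash_attention=False,  # default: torch scaled_dot_product_attention
    softcap=0.0,                # > 0 enables softcapping, only on the flash-attention path
    use_alibi_slopes=False,     # ALiBi bias is only applied on the flash-attention path
)
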
12 changes: 7 additions & 5 deletions tests/layers/block/test_block_transformer.py
@@ -30,12 +30,13 @@ class TestTransformerProcessorBlock:
activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
window_size=st.integers(min_value=1, max_value=512),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p, softcap):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)
assert isinstance(block, TransformerProcessorBlock)

@@ -53,6 +54,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
batch_size=st.integers(min_value=1, max_value=40),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_forward_output(
@@ -65,14 +67,14 @@ def test_forward_output(
shapes,
batch_size,
dropout_p,
softcap,
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)

x = torch.randn((batch_size, num_channels))

x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True)
output = block.forward(x, shapes, batch_size)
assert isinstance(output, torch.Tensor)
assert output.shape == (batch_size, num_channels)
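
The updated tests only add a softcap strategy; a hypothetical extension, not part of this PR, could cover use_alibi_slopes in the same hypothesis style on the default non-flash path. Positional arguments and the shapes/batch_size values mirror the calls in the tests above.

import torch
from hypothesis import given
from hypothesis import settings
from hypothesis import strategies as st

from anemoi.models.layers.block import TransformerProcessorBlock


@given(
    num_heads=st.sampled_from([1, 2, 4, 8]),
    dropout_p=st.floats(min_value=0.0, max_value=1.0),
    softcap=st.floats(min_value=0.0, max_value=1.0),
    use_alibi_slopes=st.booleans(),
)
@settings(max_examples=5)
def test_forward_output_with_alibi(num_heads, dropout_p, softcap, use_alibi_slopes):
    num_channels = num_heads * 16
    # num_channels, hidden_dim, num_heads, activation, window_size are passed
    # positionally, as in the existing tests; the new flags go in as keywords.
    block = TransformerProcessorBlock(
        num_channels,
        128,
        num_heads,
        "GELU",
        16,
        dropout_p=dropout_p,
        softcap=softcap,
        use_alibi_slopes=use_alibi_slopes,
    )
    x = torch.randn((8, num_channels))
    output = block.forward(x, [1, 1, 1], 8)
    assert output.shape == (8, num_channels)
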
6 changes: 6 additions & 0 deletions tests/layers/processor/test_transformer_processor.py
@@ -22,6 +22,7 @@ def transformer_processor_init():
num_heads = 16
mlp_hidden_ratio = 4
dropout_p = 0.1
softcap = 0.5
return (
num_layers,
window_size,
@@ -32,6 +33,7 @@ def transformer_processor(transformer_processor_init):
num_heads,
mlp_hidden_ratio,
dropout_p,
softcap,
)


@@ -47,6 +49,7 @@ def transformer_processor(transformer_processor_init):
num_heads,
mlp_hidden_ratio,
dropout_p,
softcap,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -58,6 +61,7 @@ def transformer_processor(transformer_processor_init):
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
softcap=softcap,
)


@@ -72,6 +76,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
_softcap,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
assert transformer_processor.num_chunks == num_chunks
@@ -90,6 +95,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
_softcap,
) = transformer_processor_init
gridsize = 100
batch_size = 1