feat: FlashMultiHeadSelfAttention
theissenhelen committed Oct 25, 2024
1 parent e335d18 commit 663fea0
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions src/anemoi/models/layers/attention.py
@@ -16,15 +16,6 @@
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
 
-try:
-    from flash_attn import flash_attn_func as attn_func
-except ImportError:
-    from torch.nn.functional import scaled_dot_product_attention as attn_func
-
-    _FLASH_ATTENTION_AVAILABLE = False
-else:
-    _FLASH_ATTENTION_AVAILABLE = True
-
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 
@@ -57,13 +48,19 @@ def __init__(
         self.is_causal = is_causal
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.attention = attn_func
 
-        if not _FLASH_ATTENTION_AVAILABLE:
-            LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
-
         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
 
+        self.attention = self.get_attention_function()
+
+    def get_attention_function(self):
+        from torch.nn.functional import scaled_dot_product_attention
+
+        return scaled_dot_product_attention
+
+    def attend(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+        return self.attention(query, key, value, is_causal=False)  # expects (batch heads grid variable) format
+
     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
@@ -89,24 +86,31 @@ def forward(
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
         dropout_p = self.dropout_p if self.training else 0.0
 
-        if _FLASH_ATTENTION_AVAILABLE:
-            query, key, value = (
-                einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
-            )
-            out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
-            out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
-        else:
-            out = self.attention(
-                query,
-                key,
-                value,
-                is_causal=False,
-                dropout_p=dropout_p,
-            )  # expects (batch heads grid variable) format
+        out = self.attend(query, key, value)
 
         out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
         out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
 
         out = self.projection(out)
 
         return out
+
+
+class FlashMultiHeadSelfAttention(MultiHeadSelfAttention):
"""Multi Head Self Attention Pytorch Layer."""

+    def get_attention_function(self):
+        from flash_attn import flash_attn_func
+
+        return flash_attn_func
+
+    def attend(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+
+        query, key, value = (
+            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        )
+
+        out = self.attention(query, key, value, causal=False, window_size=self.window_size)
+        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
+
+        return out
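Not part of the commit: a standalone sketch of the tensor-layout handling that the new `attend` hook isolates. `torch.nn.functional.scaled_dot_product_attention` (the base-class backend) consumes tensors in (batch, heads, grid, vars) layout, while `flash_attn.flash_attn_func` (the subclass backend) expects (batch, grid, heads, vars), which is why `FlashMultiHeadSelfAttention.attend` rearranges before and after the call. The dimension sizes and the `.half().cuda()` conversion below are illustrative assumptions; flash-attn only runs on fp16/bf16 CUDA tensors.

```python
# Illustrative sketch only -- not code from this commit.
import einops
import torch

batch, heads, grid, vars_ = 2, 4, 16, 8  # toy sizes, chosen arbitrarily
query, key, value = (torch.randn(batch, heads, grid, vars_) for _ in range(3))

# Base-class path: torch SDPA takes the (batch, heads, grid, vars) layout as-is.
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=False)

# Flash-attention path: move the heads axis behind the grid axis first.
q, k, v = (
    einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
try:
    from flash_attn import flash_attn_func  # optional dependency, needs a CUDA GPU

    out_flash = flash_attn_func(q.half().cuda(), k.half().cuda(), v.half().cuda(), causal=False)
    out_flash = einops.rearrange(out_flash, "batch grid heads vars -> batch heads grid vars")
except ImportError:
    print("flash-attn not installed; flash_attn_func(q, k, v, causal=False, window_size=...) would run here")
```

Keeping the default SDPA layout in the base class means only the flash-attention subclass pays for the two rearranges.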
