From 663fea035488c7915e42b9f7ff96087fef306585 Mon Sep 17 00:00:00 2001
From: theissenhelen
Date: Tue, 17 Sep 2024 13:07:47 +0100
Subject: [PATCH] feat: FlashMultiHeadSelfAttention

---
 src/anemoi/models/layers/attention.py | 58 ++++++++++++++-------------
 1 file changed, 31 insertions(+), 27 deletions(-)

diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index d7f54920..886ec046 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -16,15 +16,6 @@
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
 
-try:
-    from flash_attn import flash_attn_func as attn_func
-except ImportError:
-    from torch.nn.functional import scaled_dot_product_attention as attn_func
-
-    _FLASH_ATTENTION_AVAILABLE = False
-else:
-    _FLASH_ATTENTION_AVAILABLE = True
-
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 
@@ -57,13 +48,19 @@ def __init__(
         self.is_causal = is_causal
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.attention = attn_func
-
-        if not _FLASH_ATTENTION_AVAILABLE:
-            LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
 
         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
 
+        self.attention = self.get_attention_function()
+
+    def get_attention_function(self):
+        from torch.nn.functional import scaled_dot_product_attention
+
+        return scaled_dot_product_attention
+
+    def attend(self, query: Tensor, key: Tensor, value: Tensor, dropout_p: float = 0.0) -> Tensor:
+        return self.attention(query, key, value, is_causal=False, dropout_p=dropout_p)  # expects (batch heads grid vars)
+
     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
@@ -89,20 +86,7 @@ def forward(
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
         dropout_p = self.dropout_p if self.training else 0.0
 
-        if _FLASH_ATTENTION_AVAILABLE:
-            query, key, value = (
-                einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
-            )
-            out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
-            out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
-        else:
-            out = self.attention(
-                query,
-                key,
-                value,
-                is_causal=False,
-                dropout_p=dropout_p,
-            )  # expects (batch heads grid variable) format
+        out = self.attend(query, key, value, dropout_p)
 
         out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
         out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
@@ -110,3 +94,23 @@ def forward(
         out = self.projection(out)
 
         return out
+
+
+class FlashMultiHeadSelfAttention(MultiHeadSelfAttention):
+    """Multi Head Self Attention layer backed by flash-attn."""
+
+    def get_attention_function(self):
+        from flash_attn import flash_attn_func
+
+        return flash_attn_func
+
+    def attend(self, query: Tensor, key: Tensor, value: Tensor, dropout_p: float = 0.0) -> Tensor:
+        # flash_attn_func expects (batch grid heads vars) layout
+        query, key, value = (
+            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        )
+
+        out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
+        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
+
+        return out
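
Usage sketch (not part of the patch): with the import-time try/except removed, the caller now selects the attention backend by class. A minimal, hypothetical selection helper is shown below; the helper name build_self_attention is illustrative only, and the constructor keywords (num_heads, embed_dim) are assumed to match the existing MultiHeadSelfAttention signature.

    # Hypothetical helper, not part of this patch: pick the flash-attn backed
    # layer only when the optional dependency is installed, otherwise fall back
    # to the pure-PyTorch scaled_dot_product_attention implementation.
    import importlib.util
    import logging

    from anemoi.models.layers.attention import FlashMultiHeadSelfAttention
    from anemoi.models.layers.attention import MultiHeadSelfAttention

    LOGGER = logging.getLogger(__name__)


    def build_self_attention(num_heads: int, embed_dim: int, **kwargs) -> MultiHeadSelfAttention:
        if importlib.util.find_spec("flash_attn") is not None:
            return FlashMultiHeadSelfAttention(num_heads=num_heads, embed_dim=embed_dim, **kwargs)
        LOGGER.warning("flash-attn not available, falling back to pytorch scaled_dot_product_attention")
        return MultiHeadSelfAttention(num_heads=num_heads, embed_dim=embed_dim, **kwargs)

Either way, the call pattern in forward(x, shapes, batch_size, model_comm_group) stays the same; only attend() and get_attention_function() differ between the two classes.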