Commit

Hack for flash-attn support for inference
OpheliaMiralles committed Oct 15, 2024
1 parent f31f867 commit 54b1fd5
Showing 1 changed file with 4 additions and 0 deletions.
src/anemoi/models/layers/attention.py (4 additions, 0 deletions)
@@ -89,9 +89,13 @@ def forward(
         dropout_p = self.dropout_p if self.training else 0.0
 
         if _FLASH_ATTENTION_AVAILABLE:
+            import torch
             query, key, value = (
                 einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
             )
+            query = query.to(torch.bfloat16, non_blocking=True).to(device="cuda")
+            key = key.to(torch.bfloat16, non_blocking=True).to(device="cuda")
+            value = value.to(torch.bfloat16, non_blocking=True).to(device="cuda")
             out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
             out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         else:
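
For context, the flash-attn CUDA kernels only accept fp16/bf16 tensors resident on a CUDA device, which is what forces the casts above. Below is a minimal sketch of a dtype- and device-preserving variant of the same hack, assuming self.attention resolves to flash_attn_func from the flash_attn package (as this module appears to do when _FLASH_ATTENTION_AVAILABLE is set) and that the (batch, grid, heads, vars) layout produced by the rearrange is kept; the helper name _flash_attention_bf16 is illustrative, not part of the commit.

import torch
from flash_attn import flash_attn_func


def _flash_attention_bf16(query, key, value, window_size=(-1, -1), dropout_p=0.0):
    """Hypothetical helper: run flash-attn on bf16 CUDA copies and cast the result back."""
    orig_dtype = query.dtype
    # flash-attn requires fp16/bf16 inputs on a CUDA device.
    query, key, value = (t.to(device="cuda", dtype=torch.bfloat16, non_blocking=True) for t in (query, key, value))
    out = flash_attn_func(query, key, value, dropout_p=dropout_p, causal=False, window_size=window_size)
    # Cast back so downstream layers keep the precision they were given.
    return out.to(orig_dtype)

Casting the output back to the input dtype (which the committed hack does not do) keeps the layers after the attention block in their original precision, at the cost of one extra conversion per call.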
