Skip to content

Commit

Permalink
feat: add flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider authored Jan 18, 2023
2 parents db80168 + f3ef66e commit 3a88ef4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 14 deletions.
17 changes: 4 additions & 13 deletions a_unet/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import torch
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import pack, rearrange, reduce, repeat, unpack
from torch import Tensor, einsum, nn
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -228,28 +230,18 @@ def ConvNextV2Block(dim: int, channels: int) -> nn.Module:


def AttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
    """Multi-head attention backed by xformers memory-efficient (flash) kernels.

    Returns a Module whose forward maps q, k, v of shape
    (batch, seq, num_heads * head_features) to (batch, seq, features).
    """
    mid_features = head_features * num_heads
    # Output projection back to the requested feature size.
    to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False)

    def forward(
        q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tensor:
        # NOTE(review): `mask` is kept for interface compatibility but is
        # currently ignored by the memory-efficient path — confirm no caller
        # relies on it (it could be forwarded as `attn_bias` if needed).
        # Split heads. xformers expects inputs laid out as
        # (batch, seq_len, num_heads, head_dim), not (batch, heads, seq, dim).
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b n h d", h=num_heads), (q, k, v)
        )
        # Memory-efficient attention; the 1/sqrt(head_dim) scaling is applied
        # internally by xformers, so no explicit `scale` factor is needed here.
        out = xformers.ops.memory_efficient_attention(q, k, v)
        # Merge heads back so the last dim matches to_out's in_features.
        out = rearrange(out, "b n h d -> b n (h d)")
        return to_out(out)

    return Module([to_out], forward)


def LinearAttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
scale = head_features**-0.5
mid_features = head_features * num_heads
Expand All @@ -270,7 +262,6 @@ def forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor:

return Module([to_out], forward)


def FixedEmbedding(max_length: int, features: int):
embedding = nn.Embedding(max_length, features)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
author_email="[email protected]",
url="https://github.com/archinetai/a-unet",
keywords=["artificial intelligence", "deep learning"],
install_requires=["torch>=1.6", "data-science-types>=0.2", "einops>=0.6.0"],
install_requires=["torch>=1.6", "data-science-types>=0.2", "einops>=0.6.0", "xformers>=0.0.13"],
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
Expand Down

0 comments on commit 3a88ef4

Please sign in to comment.