forked from etched-ai/open-oasis
-
Notifications
You must be signed in to change notification settings - Fork 10
/
attention.py
101 lines (78 loc) · 3.21 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
"""
from typing import Optional
from collections import namedtuple
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along the time axis of a (B, T, H, W, D) tensor.

    Every spatial position (h, w) of every batch element is treated as an
    independent sequence of length T. Queries and keys are rotated with the
    supplied rotary embedding before scaled dot-product attention.

    Args:
        dim: Channel size of the input and output (last axis of ``x``).
        heads: Number of attention heads.
        dim_head: Channel size of each head; the projected width is
            ``heads * dim_head``.
        rotary_emb: Rotary embedding module used to rotate queries and keys.
        is_causal: If True (default), SDPA applies a causal mask over T, so a
            time step attends only to itself and earlier steps.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
    ):
        super().__init__()
        # Total width of all heads concatenated.
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head
        # One fused projection produces q, k and v (split with chunk(3) in forward).
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)
        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the T axis independently for every (batch, h, w) position.

        Args:
            x: Tensor of shape (B, T, H, W, D); D must equal ``dim`` because it
                is fed to ``to_qkv`` (an ``nn.Linear(dim, ...)``).

        Returns:
            Tensor of shape (B, T, H, W, dim).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.ndim != 5:
            raise ValueError(
                f"expected a 5-D (B, T, H, W, D) tensor, got shape {tuple(x.shape)}"
            )
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Fold the spatial grid into the batch so each (b, h, w) is its own
        # length-T sequence: (B*H*W, heads, T, dim_head).
        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        # NOTE(review): the frequencies are passed positionally as the second
        # argument; confirm this matches the signature of the rotary_embedding_torch
        # module actually imported (upstream's second parameter is seq_dim).
        q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = (t.contiguous() for t in (q, k, v))

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=self.is_causal)
        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        # Keep the attention output in q's dtype before the output projection;
        # SDPA may return a different dtype (e.g. under autocast).
        x = x.to(q.dtype)

        # Linear projection back to `dim` channels.
        return self.to_out(x)
class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the H x W grid of each frame.

    Every (batch, time) pair is treated as an independent set of H*W tokens.
    Queries and keys get 2-D (axial) rotary embeddings computed for the
    current H and W, and attention is non-causal.

    Args:
        dim: Channel size of the input and output (last axis of ``x``).
        heads: Number of attention heads.
        dim_head: Channel size of each head; the projected width is
            ``heads * dim_head``.
        rotary_emb: Rotary embedding module; its ``get_axial_freqs(H, W)`` is
            called on every forward pass.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
    ):
        super().__init__()
        # Total width of all heads concatenated.
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head
        # One fused projection produces q, k and v (split with chunk(3) in forward).
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)
        self.rotary_emb = rotary_emb

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the H*W positions independently for every (batch, t).

        Args:
            x: Tensor of shape (B, T, H, W, D); D must equal ``dim`` because it
                is fed to ``to_qkv`` (an ``nn.Linear(dim, ...)``).

        Returns:
            Tensor of shape (B, T, H, W, dim).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.ndim != 5:
            raise ValueError(
                f"expected a 5-D (B, T, H, W, D) tensor, got shape {tuple(x.shape)}"
            )
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Fold time into the batch and keep the 2-D grid explicit so the axial
        # rotary frequencies can be applied: (B*T, heads, H, W, dim_head).
        q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)

        freqs = self.rotary_emb.get_axial_freqs(H, W)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)

        # Flatten the grid into one token axis for attention: (B*T, heads, H*W, dim_head).
        q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        # Keep the attention output in q's dtype before the output projection;
        # SDPA may return a different dtype (e.g. under autocast).
        x = x.to(q.dtype)

        # Linear projection back to `dim` channels.
        return self.to_out(x)