enable flash attention for robotics transformer, to remove memory bottleneck due to lengthened action sequences
lucidrains committed Nov 28, 2023
1 parent 2975c1a commit 5d7d65a
Showing 4 changed files with 179 additions and 38 deletions.
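
For orientation (not part of the diff): the memory savings come from routing attention through PyTorch 2.0's fused `F.scaled_dot_product_attention`, which avoids materializing the full attention matrix on the flash path. Below is a minimal sketch of that call, mirroring the kernel configuration used by the new `Attend` module further down; the tensor shapes are arbitrary illustrations.

```python
import torch
import torch.nn.functional as F

# (batch, heads, sequence, dim_head) - a lengthened action sequence just grows the third axis
q = torch.randn(1, 8, 2048, 64)
k = torch.randn(1, 8, 2048, 64)
v = torch.randn(1, 8, 2048, 64)

# restrict kernel selection the same way the new Attend module does (requires torch >= 2.0)
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = True, enable_mem_efficient = True):
    # on the flash / memory-efficient paths, the (seq x seq) attention matrix is never materialized
    out = F.scaled_dot_product_attention(q, k, v)
```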
9 changes: 9 additions & 0 deletions README.md
@@ -32,3 +32,12 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S
year = {2023}
}
```

```bibtex
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```
149 changes: 149 additions & 0 deletions q_transformer/attend.py
@@ -0,0 +1,149 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

print_once = once(print)

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def maybe_reduce_mask_and(*maybe_masks):
maybe_masks = [*filter(exists, maybe_masks)]

if len(maybe_masks) == 0:
return None

mask, *rest_masks = maybe_masks

for rest_mask in rest_masks:
mask = mask & rest_mask

return mask

# main class

class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False,
flash_config: dict = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

if flash:
print_once('using memory efficient attention')

self.flash_config = flash_config

def flash_attn(self, q, k, v, mask = None, attn_mask = None):
_, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L

if exists(mask):
mask = mask.expand(-1, heads, q_len, -1)

mask = maybe_reduce_mask_and(mask, attn_mask)

# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

with torch.backends.cuda.sdp_kernel(**self.flash_config):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.
)

return out

def forward(self, q, k, v, mask = None, attn_mask = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""

q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

scale = q.shape[-1] ** -0.5

if exists(mask) and mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')

if self.flash:
return self.flash_attn(q, k, v, mask = mask, attn_mask = attn_mask)

# similarity

sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

# key padding mask

if exists(mask):
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

# attention mask

if exists(attn_mask):
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)

# aggregate values

out = einsum("b h i j, b h j d -> b h i d", attn, v)

return out
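
A hedged usage sketch of the `Attend` module defined above — not part of the commit; the shapes are arbitrary, and the flash path additionally requires PyTorch 2.0+.

```python
import torch
from q_transformer.attend import Attend

attend = Attend(dropout = 0.1, flash = True)   # flash = True asserts torch >= 2.0

# q, k, v are (batch, heads, seq, dim_head); mask is a (batch, key_seq) key padding mask
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)
mask = torch.ones(2, 1024).bool()

out = attend(q, k, v, mask = mask)             # (2, 8, 1024, 64)
```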
57 changes: 20 additions & 37 deletions q_transformer/robotic_transformer.py
Expand Up @@ -12,6 +12,8 @@
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

from q_transformer.attend import Attend

from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance

# helpers
@@ -414,18 +416,16 @@ class TransformerAttention(Module):
def __init__(
self,
dim,
causal = False,
dim_head = 64,
dim_context = None,
heads = 8,
num_mem_kv = 4,
norm_context = False,
dropout = 0.1
dropout = 0.1,
flash = True
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
inner_dim = dim_head * heads

dim_context = default(dim_context, dim)
@@ -436,12 +436,17 @@ def __init__(
self.attn_dropout = nn.Dropout(dropout)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

self.num_mem_kv = num_mem_kv
self.mem_kv = None
if num_mem_kv > 0:
self.mem_kv = nn.Parameter(torch.randn(2, num_mem_kv, dim_head))
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))

self.attend = Attend(
dropout = dropout,
flash = flash
)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
@@ -453,7 +458,6 @@ def forward(
x,
context = None,
mask = None,
attn_bias = None,
attn_mask = None,
cond_fn: Optional[Callable] = None
):
@@ -472,46 +476,21 @@ def forward(

q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

q = q * self.scale
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

if exists(self.mem_kv):
mk, mv = map(lambda t: repeat(t, '... -> b ...', b = b), self.mem_kv)

k = torch.cat((mk, k), dim = -2)
v = torch.cat((mv, v), dim = -2)

if exists(attn_bias) and self.num_mem_kv > 0:
attn_bias = F.pad(attn_bias, (self.num_mem_kv, 0), value = 0.)

if exists(mask):
mask = F.pad(mask, (self.num_mem_kv, 0), value = True)

if exists(attn_mask):
attn_mask = F.pad(attn_mask, (self.num_mem_kv, 0), value = True)

sim = einsum('b h i d, b j d -> b h i j', q, k)

if exists(attn_bias):
sim = sim + attn_bias

if exists(attn_mask):
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)

out = einsum('b h i j, b j d -> b h i d', attn, v)
out = self.attend(q, k, v, mask = mask, attn_mask = attn_mask)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -525,13 +504,14 @@ def __init__(
heads = 8,
depth = 6,
attn_dropout = 0.,
ff_dropout = 0.
ff_dropout = 0.,
flash_attn = True
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout),
TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn),
FeedForward(dim = dim, dropout = ff_dropout)
]))

@@ -649,6 +629,7 @@ def __init__(
concat_action_embeddings = False, # will allow for action embeddings to be concatted just before attention layers - https://arxiv.org/abs/2309.10150 figure 3.
action_dim = 16, # dimension of action embedding, defaults to embedding dimension of maxvit
dueling = False, # https://arxiv.org/abs/1511.06581
flash_attn = True
):
super().__init__()

@@ -706,7 +687,8 @@ def __init__(
dim = attend_dim,
dim_head = dim_head,
heads = heads,
depth = depth
depth = depth,
flash_attn = flash_attn
)

self.cond_drop_prob = cond_drop_prob
@@ -829,4 +811,5 @@ def forward(
if self.num_actions == 1:
q_values = rearrange(q_values, '... 1 b -> ... b')

return q_values
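
And a hedged sketch of the reworked attention layer in use — the import path is assumed from the file shown above, the dimensions are arbitrary, and only the `flash` flag is what this commit introduces.

```python
import torch
from q_transformer.robotic_transformer import TransformerAttention  # import path assumed from the file above

# flash = True hands q, k, v to the Attend module instead of the removed manual einsum/softmax path
attn = TransformerAttention(dim = 512, dim_head = 64, heads = 8, flash = True)

tokens = torch.randn(2, 256, 512)   # (batch, sequence, dim)
out = attn(tokens)                  # (2, 256, 512)
```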
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',