diff --git a/README.md b/README.md
index 9f52a31..024eb14 100644
--- a/README.md
+++ b/README.md
@@ -32,3 +32,12 @@ Implementation of Q-Transformer, S
     year    = {2023}
 }
 ```
+
+```bibtex
+@inproceedings{dao2022flashattention,
+    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year      = {2022}
+}
+```
diff --git a/q_transformer/attend.py b/q_transformer/attend.py
new file mode 100644
index 0000000..cd19f09
--- /dev/null
+++ b/q_transformer/attend.py
@@ -0,0 +1,140 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, reduce
+
+# constants
+
+FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+# helpers
+
+def once(fn):
+    called = False
+    @wraps(fn)
+    def inner(x):
+        nonlocal called
+        if called:
+            return
+        called = True
+        return fn(x)
+    return inner
+
+print_once = once(print)
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def maybe_reduce_mask_and(*maybe_masks):
+    maybe_masks = [*filter(exists, maybe_masks)]
+
+    if len(maybe_masks) == 0:
+        return None
+
+    mask, *rest_masks = maybe_masks
+
+    for rest_mask in rest_masks:
+        mask = mask & rest_mask
+
+    return mask
+
+# main class
+
+class Attend(nn.Module):
+    def __init__(
+        self,
+        dropout = 0.,
+        flash = False,
+        flash_config: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        )
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.flash = flash
+        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        if flash:
+            print_once('using memory efficient attention')
+
+        self.flash_config = flash_config
+
+    def flash_attn(self, q, k, v, mask = None, attn_mask = None):
+        _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+        # check if the key padding mask exists and expand it to a compatible shape
+        # it arrives as (b 1 1 j) from forward and is expanded to (b h q_len j)
+
+        if exists(mask):
+            mask = mask.expand(-1, heads, q_len, -1)
+
+        mask = maybe_reduce_mask_and(mask, attn_mask)
+
+        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+        with torch.backends.cuda.sdp_kernel(**self.flash_config):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask = mask,
+                dropout_p = self.dropout if self.training else 0.
+            )
+
+        return out
+
+    def forward(self, q, k, v, mask = None, attn_mask = None):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        n, i, j - sequence length (base sequence length, source, target)
+        d - feature dimension
+        """
+
+        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+        scale = q.shape[-1] ** -0.5
+
+        if exists(mask) and mask.ndim != 4:
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
+        if self.flash:
+            return self.flash_attn(q, k, v, mask = mask, attn_mask = attn_mask)
+
+        # similarity
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale
+
+        # key padding mask
+
+        if exists(mask):
+            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
+
+        # attention mask
+
+        if exists(attn_mask):
+            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
+
+        # attention
+
+        attn = sim.softmax(dim = -1)
+        attn = self.attn_dropout(attn)
+
+        # aggregate values
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        return out
diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py
index 7784fb5..0dc684a 100644
--- a/q_transformer/robotic_transformer.py
+++ b/q_transformer/robotic_transformer.py
@@ -12,6 +12,8 @@
 from einops import pack, unpack, repeat, reduce, rearrange
 from einops.layers.torch import Rearrange, Reduce
 
+from q_transformer.attend import Attend
+
 from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance
 
 # helpers
@@ -414,18 +416,16 @@ class TransformerAttention(Module):
     def __init__(
         self,
         dim,
-        causal = False,
         dim_head = 64,
         dim_context = None,
         heads = 8,
         num_mem_kv = 4,
         norm_context = False,
-        dropout = 0.1
+        dropout = 0.1,
+        flash = True
     ):
         super().__init__()
         self.heads = heads
-        self.scale = dim_head ** -0.5
-        self.causal = causal
         inner_dim = dim_head * heads
 
         dim_context = default(dim_context, dim)
@@ -436,12 +436,17 @@ def __init__(
         self.attn_dropout = nn.Dropout(dropout)
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
-        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
+        self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
 
         self.num_mem_kv = num_mem_kv
         self.mem_kv = None
         if num_mem_kv > 0:
-            self.mem_kv = nn.Parameter(torch.randn(2, num_mem_kv, dim_head))
+            self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
+
+        self.attend = Attend(
+            dropout = dropout,
+            flash = flash
+        )
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
@@ -453,7 +458,6 @@ def forward(
         x,
         context = None,
         mask = None,
-        attn_bias = None,
         attn_mask = None,
         cond_fn: Optional[Callable] = None
     ):
@@ -472,9 +476,7 @@ def forward(
 
         q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
 
-        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
-
-        q = q * self.scale
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
         if exists(self.mem_kv):
             mk, mv = map(lambda t: repeat(t, '... -> b ...', b = b), self.mem_kv)
@@ -482,36 +484,13 @@ def forward(
             k = torch.cat((mk, k), dim = -2)
             v = torch.cat((mv, v), dim = -2)
 
-        if exists(attn_bias) and self.num_mem_kv > 0:
-            attn_bias = F.pad(attn_bias, (self.num_mem_kv, 0), value = 0.)
-
         if exists(mask):
             mask = F.pad(mask, (self.num_mem_kv, 0), value = True)
 
         if exists(attn_mask):
             attn_mask = F.pad(attn_mask, (self.num_mem_kv, 0), value = True)
 
-        sim = einsum('b h i d, b j d -> b h i j', q, k)
-
-        if exists(attn_bias):
-            sim = sim + attn_bias
-
-        if exists(attn_mask):
-            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
-
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
-            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
-
-        if self.causal:
-            i, j = sim.shape[-2:]
-            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
-            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
-
-        attn = sim.softmax(dim = -1)
-        attn = self.attn_dropout(attn)
-
-        out = einsum('b h i j, b j d -> b h i d', attn, v)
+        out = self.attend(q, k, v, mask = mask, attn_mask = attn_mask)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -525,13 +504,14 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
         depth = 6,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        flash_attn = True
     ):
         super().__init__()
         self.layers = ModuleList([])
         for _ in range(depth):
             self.layers.append(ModuleList([
-                TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout),
+                TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn),
                 FeedForward(dim = dim, dropout = ff_dropout)
             ]))
@@ -649,6 +629,7 @@ def __init__(
         concat_action_embeddings = False, # will allow for action embeddings to be concatted just before attention layers - https://arxiv.org/abs/2309.10150 figure 3.
         action_dim = 16,                  # dimension of action embedding, defaults to embedding dimension of maxvit
         dueling = False,                  # https://arxiv.org/abs/1511.06581
+        flash_attn = True
     ):
         super().__init__()
 
@@ -706,7 +687,8 @@ def __init__(
             dim = attend_dim,
             dim_head = dim_head,
             heads = heads,
-            depth = depth
+            depth = depth,
+            flash_attn = flash_attn
         )
 
         self.cond_drop_prob = cond_drop_prob
diff --git a/setup.py b/setup.py
index 516ba71..c47a0d9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
  packages = find_packages(exclude=[]),
-  version = '0.0.10',
+  version = '0.0.11',
  license='MIT',
  description = 'Q-Transformer',
  author = 'Phil Wang',
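
Not part of the patch above — a minimal usage sketch of the new `Attend` module, with illustrative tensor shapes chosen for the example (`flash = True` requires PyTorch 2.0+):

```python
import torch
from q_transformer.attend import Attend

# q, k, v arrive as (batch, heads, seq_len, dim_head),
# matching TransformerAttention after its rearrange
q = torch.randn(2, 8, 256, 64)
k = torch.randn(2, 8, 256, 64)
v = torch.randn(2, 8, 256, 64)

# key padding mask of shape (batch, seq_len); True marks tokens to keep
mask = torch.ones(2, 256).bool()

# flash = True routes through F.scaled_dot_product_attention instead of the manual einsum path
attend = Attend(dropout = 0., flash = False)

out = attend(q, k, v, mask = mask)
assert out.shape == (2, 8, 256, 64)
```

End to end, the same path is toggled by `QRoboticTransformer(..., flash_attn = True)`, which threads the flag through `Transformer` and `TransformerAttention` down to `Attend` as shown in the hunks above.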