layernorm is enough
lucidrains committed Nov 27, 2023
1 parent adf9449 commit 59265b9
Showing 3 changed files with 31 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -11,6 +11,7 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S
## Todo

- [x] first work way towards single action support
- [x] offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3

- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
- [ ] do n-step Q learning, even though not that big of improvement
@@ -20,7 +21,6 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S
- [ ] improvise a cross attention variant instead of concatenating previous actions? (could have wrong intuition here)
- [ ] see if the main idea in this paper is applicable to language models
- [ ] consult some RL experts and figure out if there are any new headways into resolving <a href="https://www.cs.toronto.edu/~cebly/Papers/CONQUR_ICML_2020_camera_ready.pdf">delusional bias</a>
- [ ] offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3

## Citations

38 changes: 29 additions & 9 deletions q_transformer/robotic_transformer.py
@@ -1,4 +1,4 @@
from functools import cache
from functools import partial, cache

import torch
import torch.nn.functional as F
@@ -12,8 +12,6 @@
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

from functools import partial

from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance

# helpers
@@ -43,6 +41,20 @@ def MaybeSyncBatchnorm2d(is_distributed = None):
is_distributed = default(is_distributed, get_is_distributed())
return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

# channel layernorm

class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
self.beta = nn.Parameter(torch.zeros(dim, 1, 1))

def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma + self.beta

# sinusoidal positions

def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
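
For reference, a quick standalone check (not part of the commit) of what the new `ChanLayerNorm` computes: it is a layer norm taken over the channel axis of an NCHW tensor, one normalization per spatial position. The hunk clamps the variance to `eps` rather than adding it, so it only matches `nn.LayerNorm` up to a small tolerance; the class definition below is copied from the diff above.

```python
import torch
import torch.nn as nn
from einops import rearrange

class ChanLayerNorm(nn.Module):
    # copied verbatim from the hunk above
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma + self.beta

dim = 64
x = torch.randn(2, dim, 7, 7)

out = ChanLayerNorm(dim)(x)

# reference: nn.LayerNorm over the channel axis, applied independently at each (h, w) position
ref = rearrange(nn.LayerNorm(dim)(rearrange(x, 'b c h w -> b h w c')), 'b h w c -> b c h w')

assert torch.allclose(out, ref, atol = 1e-4)
print(out.shape)  # torch.Size([2, 64, 7, 7])
```
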
@@ -87,6 +99,7 @@ def __init__(self, dim, mult = 4, dropout = 0.):
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, cond_fn = None):
x = self.norm(x)

@@ -149,22 +162,27 @@ def MBConv(
expansion_rate = 4,
shrinkage_rate = 0.25,
dropout = 0.,
is_distributed = None
is_distributed = None,
use_layernorm = True
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
batchnorm_klass = MaybeSyncBatchnorm2d(is_distributed)

if use_layernorm:
norm_klass = ChanLayerNorm
else:
norm_klass = MaybeSyncBatchnorm2d(is_distributed)

net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
batchnorm_klass(hidden_dim),
norm_klass(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
batchnorm_klass(hidden_dim),
norm_klass(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
batchnorm_klass(dim_out)
norm_klass(dim_out)
)

if dim_in == dim_out and not downsample:
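
The practical payoff of the `use_layernorm` branch is that `ChanLayerNorm` normalizes each sample on its own, so an MBConv stack carries no cross-batch coupling and no train/eval statistics mismatch, which is the property the metnet3-style batchnorm-free variant mentioned in the README is after. A small illustration (my own sketch, not the repo's code), assuming `ChanLayerNorm` can be imported from `q_transformer.robotic_transformer` once this commit is installed:

```python
import torch
import torch.nn as nn
from q_transformer.robotic_transformer import ChanLayerNorm  # assumes this commit is installed

torch.manual_seed(0)
dim = 8
x = torch.randn(4, dim, 16, 16)

bn = nn.BatchNorm2d(dim).train()
ln = ChanLayerNorm(dim)

# batchnorm in training mode: the first sample's output depends on its batch neighbors
print(torch.allclose(bn(x)[:1], bn(x[:1])))  # False - batch statistics differ

# channel layernorm: the first sample's output is identical with or without neighbors
print(torch.allclose(ln(x)[:1], ln(x[:1])))  # True
```
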
@@ -288,6 +306,7 @@ def __init__(
window_size = 7,
mbconv_expansion_rate = 4,
mbconv_shrinkage_rate = 0.25,
use_layernorm = True,
dropout = 0.1,
channels = 3
):
@@ -334,7 +353,8 @@ def __init__(
layer_dim,
downsample = is_first,
expansion_rate = mbconv_expansion_rate,
shrinkage_rate = mbconv_shrinkage_rate
shrinkage_rate = mbconv_shrinkage_rate,
use_layernorm = use_layernorm
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
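
A hypothetical end-to-end sketch: with `use_layernorm = True` (now the default) the flag reaches every `MBConv` block, so the vision trunk should end up with no batchnorm modules at all. This assumes the backbone class is exported as `MaxViT` from `q_transformer.robotic_transformer`; the constructor arguments other than `use_layernorm` and `window_size` (`num_classes`, `dim`, `depth`, `dim_head`) are assumed from the usual MaxViT implementations and may not match this repo exactly.

```python
import torch.nn as nn
from q_transformer.robotic_transformer import MaxViT  # assumes this commit is installed

vit = MaxViT(
    num_classes = 1000,    # assumed argument
    dim = 96,              # assumed argument
    depth = (2, 2, 5, 2),  # assumed argument
    dim_head = 32,         # assumed argument
    window_size = 7,
    use_layernorm = True   # the flag added in this commit
)

# expect False if, as in the reference MaxViT, the MBConv blocks were the only batchnorm users
print(any(isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)) for m in vit.modules()))
```
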
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.3',
version = '0.0.4',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
