From dbd4b5a1c52d40b6b3ba1687cc8ca7d1efa943eb Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 12 Dec 2019 13:57:19 -0800
Subject: [PATCH] add a positional encoding MLP to the mix, helps with low-d
 curled up data potentially

---
 nflib/nets.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/nflib/nets.py b/nflib/nets.py
index 673fb9b..c2f8ea4 100644
--- a/nflib/nets.py
+++ b/nflib/nets.py
@@ -20,6 +20,21 @@ def __init__(self, n):
     def forward(self, x):
         return self.p.expand(x.size(0), self.p.size(1))
 
+class PositionalEncoder(nn.Module):
+    """
+    Each dimension of the input gets expanded out with sins/coses
+    to "carve" out the space. Useful in low-dimensional cases with
+    tightly "curled up" data.
+    """
+    def __init__(self, freqs=(.5,1,2,4,8)):
+        super().__init__()
+        self.freqs = freqs
+
+    def forward(self, x):
+        sines = [torch.sin(x * f) for f in self.freqs]
+        coses = [torch.cos(x * f) for f in self.freqs]
+        out = torch.cat(sines + coses, dim=1)
+        return out
 
 class MLP(nn.Module):
     """ a simple 4-layer MLP """
@@ -38,6 +53,21 @@ def __init__(self, nin, nout, nh):
     def forward(self, x):
         return self.net(x)
 
+class PosEncMLP(nn.Module):
+    """
+    Position Encoded MLP, where the first layer performs position encoding.
+    Each dimension of the input gets transformed to len(freqs)*2 dimensions
+    using a fixed transformation of sin/cos of given frequencies.
+    """
+    def __init__(self, nin, nout, nh, freqs=(.5,1,2,4,8)):
+        super().__init__()
+        self.net = nn.Sequential(
+            PositionalEncoder(freqs),
+            MLP(nin * len(freqs) * 2, nout, nh),
+        )
+    def forward(self, x):
+        return self.net(x)
+
 class ARMLP(nn.Module):
     """ a 4-layer auto-regressive MLP, wrapper around MADE net """
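
Not part of the patch: a minimal smoke-test sketch of how the new PosEncMLP
could be exercised, assuming the patched nflib/nets.py is importable as
nflib.nets and that torch is installed. The dimensions and batch size below
are illustrative only.

    import torch
    from nflib.nets import PosEncMLP  # assumes the patched module is on the import path

    # hypothetical example: 2-d inputs, 2-d outputs, 24 hidden units, default frequencies
    net = PosEncMLP(nin=2, nout=2, nh=24, freqs=(.5, 1, 2, 4, 8))
    x = torch.randn(128, 2)   # batch of 128 two-dimensional points
    y = net(x)                # PositionalEncoder: 2 -> 2*len(freqs)*2 = 20 dims, then MLP: 20 -> 2
    print(y.shape)            # expected: torch.Size([128, 2])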