forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LA.lua
executable file
·43 lines (31 loc) · 942 Bytes
/
LA.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
require 'nn'
--Based on: http://arxiv.org/pdf/1412.6830v1.pdf
--The output is flattened to inputSize values, so if the input has more than one dimension, add an nn.Reshape after this module to restore its shape.
--Usage:
------------------------------------
-- model:add(LA(4, 3 * 32 * 32))
-- model:add(nn.Reshape(3,32,32))
------------------------------------
--- Build one hinge term of the APL unit: a * max(0, -x + b).
-- @tparam number inputSize number of input elements (units)
-- @return an nn.Sequential computing the hinge term for every unit
local function apl_hinge(inputSize)
   local hinge = nn.Sequential()
   hinge:add(nn.MulConstant(-1.0))    -- -x
   -- NOTE(review): the `true` flag appears to request a single shared (scalar)
   -- bias for this hinge, whereas nn.CMul below learns one slope per unit;
   -- confirm that sharing b across units is intended.
   hinge:add(nn.Add(inputSize, true)) -- -x + b
   hinge:add(nn.ReLU())               -- max(0, -x + b)
   hinge:add(nn.CMul(inputSize))      -- a * max(0, -x + b)
   return hinge
end

--- Adaptive piecewise linear activation (APL).
-- Computes h(x) = max(0, x) + sum_{k=1..s} a_k * max(0, -x + b_k),
-- where the a_k and b_k are learnable.
-- @tparam number s number of hinge terms (non-negative integer; 0 gives a plain ReLU)
-- @tparam number inputSize number of input elements (positive integer)
-- @return an nn.Sequential module; its output has inputSize values
-- @raise error if s or inputSize is not an integer in range
function LA(s, inputSize)
   -- Validate up front: a fractional s would create a non-sequence table key and
   -- a fractional Reshape size, failing later with an unrelated-looking error.
   if type(s) ~= 'number' or s < 0 or s ~= math.floor(s) then
      error('LA: s must be a non-negative integer, got ' .. tostring(s), 2)
   end
   if type(inputSize) ~= 'number' or inputSize < 1
         or inputSize ~= math.floor(inputSize) then
      error('LA: inputSize must be a positive integer, got ' .. tostring(inputSize), 2)
   end

   -- Every branch sees the same input: s hinge terms first, then the rectifier.
   local branches = nn.ConcatTable()
   for _ = 1, s do
      branches:add(apl_hinge(inputSize))
   end
   local rectifier = nn.Sequential()
   rectifier:add(nn.ReLU()) -- max(0, x)
   branches:add(rectifier)

   local module = nn.Sequential()
   module:add(branches)
   -- Stack the s + 1 branch outputs, view them as (s + 1) x inputSize and sum
   -- over the first dimension, yielding one value per unit (inputSize values).
   -- NOTE(review): JoinTable(1) concatenates along the first dimension, so this
   -- path appears to assume a single unbatched sample; a leading minibatch
   -- dimension would not be preserved -- confirm before using with minibatches.
   module:add(nn.JoinTable(1))
   module:add(nn.Reshape(s + 1, inputSize))
   module:add(nn.Sum(1))
   return module
end