-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4df3893
Showing
127 changed files
with
10,444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
-- nn.Abs: elementwise absolute-value module. The actual computation lives
-- in the C backend (Abs_updateOutput / Abs_updateGradInput); this file only
-- provides the Lua-side class glue.
local Abs, parent = torch.class('nn.Abs', 'nn.Module')

function Abs:__init()
   -- No parameters or extra state; nn.Module's constructor sets up
   -- self.output / self.gradInput.
   parent.__init(self)
end
|
||
-- Forward pass. Dispatches through input.nn so the C kernel matching the
-- input's tensor type is used; the kernel presumably fills self.output in
-- place (the Lua side just returns it) — confirm against the C source.
function Abs:updateOutput(input)
   input.nn.Abs_updateOutput(self, input)
   return self.output
end
|
||
-- Backward pass. As with updateOutput, the type-dispatched C kernel is
-- expected to populate self.gradInput in place.
function Abs:updateGradInput(input, gradOutput)
   input.nn.Abs_updateGradInput(self, input, gradOutput)
   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
-- nn.AbsCriterion: absolute-value (L1) loss criterion. The reduction itself
-- is implemented in the C backend; see AbsCriterion_updateOutput.
local AbsCriterion, parent = torch.class('nn.AbsCriterion', 'nn.Criterion')

function AbsCriterion:__init()
   parent.__init(self)
   -- When true the backend presumably averages the loss over elements
   -- instead of summing — confirm against the C source.
   self.sizeAverage = true
end
|
||
-- Forward pass: return the loss computed by the type-dispatched C kernel.
function AbsCriterion:updateOutput(input, target)
   return input.nn.AbsCriterion_updateOutput(self, input, target)
end
|
||
-- Backward pass: return the gradient computed by the C kernel.
function AbsCriterion:updateGradInput(input, target)
   return input.nn.AbsCriterion_updateGradInput(self, input, target)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
local Add, parent = torch.class('nn.Add', 'nn.Module')

-- Adds a learnable bias to its input. With `scalar` set, a single shared
-- scalar is added to every element; otherwise there is one bias per element.
function Add:__init(inputSize, scalar)
   parent.__init(self)

   -- A scalar bias needs storage for exactly one value.
   local size = scalar and 1 or inputSize
   self.bias = torch.Tensor(size)
   self.gradBias = torch.Tensor(size)

   -- state
   self.gradInput:resize(inputSize)
   self.output:resize(inputSize)

   self:reset()
end
|
||
-- Re-initialize the bias uniformly. With an explicit stdv the range is
-- widened by sqrt(3); otherwise the usual 1/sqrt(size) heuristic is used.
function Add:reset(stdv)
   local n = self.bias:size(1)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1. / math.sqrt(n)
   end
   for k = 1, n do
      self.bias[k] = torch.uniform(-stdv, stdv)
   end
end
|
||
-- Forward: output = input + bias (a shared scalar when the bias has a
-- single element, elementwise otherwise).
function Add:updateOutput(input)
   self.output:copy(input)
   -- select number-vs-tensor addend; numbers are always truthy in Lua,
   -- so the and/or selection is safe even for a zero bias
   local addend = (self.gradBias:size(1) == 1) and self.bias[1] or self.bias
   self.output:add(addend)
   return self.output
end
|
||
-- Backward wrt input: the bias addition is identity in the input, so the
-- gradient passes straight through. Skips work when gradInput is disabled.
function Add:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end
   self.gradInput:copy(gradOutput)
   return self.gradInput
end
|
||
-- Accumulate the bias gradient: the total (scaled) incoming gradient for a
-- scalar bias, the elementwise gradient otherwise.
function Add:accGradParameters(input, gradOutput, scale)
   local s = scale or 1
   if self.gradBias:size(1) ~= 1 then
      self.gradBias:add(s, gradOutput)
   else
      self.gradBias[1] = self.gradBias[1] + s * gradOutput:sumall()
   end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
-- nn.CAddTable: takes a table of tensors and outputs their elementwise sum.
local CAddTable, parent = torch.class('nn.CAddTable', 'nn.Module')

function CAddTable:__init()
   parent.__init(self)
   -- one gradient tensor per table entry
   self.gradInput = {}
end
|
||
-- Forward: output = input[1] + input[2] + ... (elementwise).
function CAddTable:updateOutput(input)
   local first = input[1]
   self.output:resizeAs(first):copy(first)
   for k = 2, #input do
      self.output:add(input[k])
   end
   return self.output
end
|
||
-- Backward: addition distributes the incoming gradient unchanged to every
-- input. Gradient tensors are created lazily and reused across calls.
function CAddTable:updateGradInput(input, gradOutput)
   for k = 1, #input do
      local g = self.gradInput[k] or torch.Tensor()
      self.gradInput[k] = g
      g:resizeAs(input[k]):copy(gradOutput)
   end
   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
-- nn.CDivTable: takes a table {a, b} of two tensors and outputs the
-- elementwise quotient a / b.
local CDivTable, parent = torch.class('nn.CDivTable', 'nn.Module')

function CDivTable:__init()
   parent.__init(self)
   -- one gradient tensor per table entry
   self.gradInput = {}
end
|
||
-- Forward: output = input[1] / input[2] (elementwise).
function CDivTable:updateOutput(input)
   local numerator, denominator = input[1], input[2]
   self.output:resizeAs(numerator):copy(numerator):cdiv(denominator)
   return self.output
end
|
||
-- Backward for z = a / b (elementwise):
--   dz/da =  1 / b      -> gradOutput / input[2]
--   dz/db = -a / b^2    -> -(gradOutput / input[2]) * input[1] / input[2]
-- Note gradInput[1] must be computed first since gradInput[2] reuses it.
function CDivTable:updateGradInput(input, gradOutput)
   self.gradInput[1] = self.gradInput[1] or torch.Tensor()
   self.gradInput[2] = self.gradInput[2] or torch.Tensor()
   -- gradient wrt the numerator
   self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2])
   -- gradient wrt the denominator, built from gradInput[1]
   self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1])
   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Build configuration for the Torch "nn" package: one C backend file plus
# all Lua sources in this directory.
SET(src init.c)

# Ship every Lua module, plus the test harness.
FILE(GLOB luasrc *.lua)
SET(luasrc ${luasrc} test/test.lua)

ADD_TORCH_PACKAGE(nn "${src}" "${luasrc}" "Machine Learning")
ADD_TORCH_DOK(dok nn "Machine Learning" "Neural Networks" 3.1)

# Link against the Lua/C glue (luaT) and the tensor library (TH).
TARGET_LINK_LIBRARIES(nn luaT TH)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
-- nn.CMul: multiplies the input elementwise by a learnable weight vector
-- of size `inputSize`.
local CMul, parent = torch.class('nn.CMul', 'nn.Module')

function CMul:__init(inputSize)
   parent.__init(self)

   self.weight = torch.Tensor(inputSize)
   self.gradWeight = torch.Tensor(inputSize)

   -- state
   self.gradInput:resize(inputSize)
   self.output:resize(inputSize)

   self:reset()
end
|
||
-- Reset all weights to 1, the multiplicative identity (unlike most modules,
-- no random initialization here).
function CMul:reset()
   self.weight:fill(1)
end
|
||
-- Forward: output = input * weight (elementwise).
function CMul:updateOutput(input)
   self.output:copy(input):cmul(self.weight)
   return self.output
end
|
||
-- Backward wrt input: gradInput = weight * gradOutput (elementwise).
-- Does nothing when gradInput has been disabled (set to a falsy value).
function CMul:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end
   self.gradInput:zero()
   self.gradInput:addcmul(1, self.weight, gradOutput)
   return self.gradInput
end
|
||
-- Accumulate gradWeight += scale * input * gradOutput (elementwise).
function CMul:accGradParameters(input, gradOutput, scale)
   local s = scale or 1
   self.gradWeight:addcmul(s, input, gradOutput)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
-- nn.CMulTable: takes a table of tensors and outputs their elementwise
-- product.
local CMulTable, parent = torch.class('nn.CMulTable', 'nn.Module')

function CMulTable:__init()
   parent.__init(self)
   -- one gradient tensor per table entry
   self.gradInput = {}
end
|
||
-- Forward: output = input[1] * input[2] * ... (elementwise).
function CMulTable:updateOutput(input)
   local first = input[1]
   self.output:resizeAs(first):copy(first)
   for k = 2, #input do
      self.output:cmul(input[k])
   end
   return self.output
end
|
||
-- Backward: d(prod_j x_j)/dx_i = output / x_i, so
--   gradInput[i] = gradOutput * output / input[i]
-- (undefined where input[i] contains zeros, as in the original code).
function CMulTable:updateGradInput(input, gradOutput)
   -- Reuse a scratch buffer across calls instead of allocating a fresh
   -- tensor on every backward pass.
   self.tout = self.tout or torch.Tensor()
   local tout = self.tout
   tout:resizeAs(self.output)
   for i = 1, #input do
      self.gradInput[i] = self.gradInput[i] or torch.Tensor()
      self.gradInput[i]:resizeAs(input[i]):copy(gradOutput)
      tout:copy(self.output):cdiv(input[i])
      self.gradInput[i]:cmul(tout)
   end
   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
-- nn.CSubTable: takes a table {a, b} of two tensors and outputs the
-- elementwise difference a - b.
local CSubTable, parent = torch.class('nn.CSubTable', 'nn.Module')

function CSubTable:__init()
   parent.__init(self)
   -- one gradient tensor per table entry
   self.gradInput = {}
end
|
||
-- Forward: output = input[1] - input[2].
function CSubTable:updateOutput(input)
   self.output:resizeAs(input[1]):copy(input[1]):add(-1, input[2])
   return self.output
end
|
||
-- Backward: d(a-b)/da = 1, d(a-b)/db = -1, so the incoming gradient passes
-- through unchanged for the first input and negated for the second.
function CSubTable:updateGradInput(input, gradOutput)
   for k = 1, 2 do
      self.gradInput[k] = self.gradInput[k] or torch.Tensor()
   end
   self.gradInput[1]:resizeAs(input[1]):copy(gradOutput)
   self.gradInput[2]:resizeAs(input[1]):copy(gradOutput):mul(-1)
   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
-- nn.ClassNLLCriterion: negative log-likelihood criterion for
-- classification. Expects input to hold log-probabilities (one row per
-- sample in batch mode) and target to hold class indices.
local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')

function ClassNLLCriterion:__init()
   parent.__init(self)
   -- average the batch loss over samples by default
   self.sizeAverage = true
end
|
||
-- Forward: loss = -input[target] for a single sample (1-D input), or the
-- (optionally averaged) sum of -input[i][target[i]] over a batch (2-D).
function ClassNLLCriterion:updateOutput(input, target)
   if input:dim() == 1 then
      self.output = -input[target]
      return self.output
   end
   if input:dim() ~= 2 then
      error('matrix or vector expected')
   end
   local n = target:size(1)
   local loss = 0
   for i = 1, n do
      loss = loss - input[i][target[i]]
   end
   if self.sizeAverage then
      loss = loss / n
   end
   self.output = loss
   return self.output
end
|
||
-- Backward: the gradient is zero everywhere except at the target entries,
-- which receive -1 (divided by the batch size when sizeAverage is on).
function ClassNLLCriterion:updateGradInput(input, target)
   self.gradInput:resizeAs(input):zero()

   if input:dim() == 1 then
      self.gradInput[target] = -1
      return self.gradInput
   end

   local n = target:size(1)
   -- -1/n is a number and therefore truthy, so and/or selection is safe
   local g = self.sizeAverage and (-1 / n) or -1
   local gradInput = self.gradInput
   for i = 1, n do
      gradInput[i][target[i]] = g
   end

   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
-- nn.Concat: applies each submodule to the same input and concatenates
-- their outputs along `dimension`.
local Concat, parent = torch.class('nn.Concat', 'nn.Module')

function Concat:__init(dimension)
   parent.__init(self)
   self.modules = {}
   -- cached size of the concatenated output, rebuilt on each forward pass
   self.size = torch.LongStorage()
   self.dimension = dimension
end
|
||
-- Append a submodule; returns self so calls can be chained.
function Concat:add(module)
   self.modules[#self.modules + 1] = module
   return self
end
|
||
-- Return the index-th submodule.
function Concat:get(index)
   return self.modules[index]
end
|
||
-- Forward: run every submodule on `input` and concatenate the results
-- along self.dimension.
function Concat:updateOutput(input)
   -- First pass: run each submodule once and accumulate the output size
   -- along the concatenation dimension.
   for i = 1, #self.modules do
      local currentOutput = self.modules[i]:updateOutput(input)
      if i == 1 then
         self.size:resize(currentOutput:dim()):copy(currentOutput:size())
      else
         self.size[self.dimension] = self.size[self.dimension] + currentOutput:size(self.dimension)
      end
   end
   self.output:resize(self.size)

   -- Second pass: copy each cached output into its slice. The original
   -- code re-ran module:updateOutput(input) here, doing every forward
   -- computation twice; module.output already holds the result.
   local offset = 1
   for _, module in ipairs(self.modules) do
      local currentOutput = module.output
      self.output:narrow(self.dimension, offset, currentOutput:size(self.dimension)):copy(currentOutput)
      offset = offset + currentOutput:size(self.dimension)
   end
   return self.output
end
|
||
-- Backward wrt input: each submodule receives the slice of gradOutput
-- matching its output; the per-module input gradients are summed.
function Concat:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(input)

   local offset = 1
   for i = 1, #self.modules do
      local module = self.modules[i]
      local width = module.output:size(self.dimension)
      local slice = gradOutput:narrow(self.dimension, offset, width)
      local currentGradInput = module:updateGradInput(input, slice)
      if i == 1 then
         self.gradInput:copy(currentGradInput)
      else
         self.gradInput:add(currentGradInput)
      end
      offset = offset + width
   end
   return self.gradInput
end
|
||
-- Accumulate parameter gradients in each submodule using its slice of
-- gradOutput. (The original bound accGradParameters' return value to an
-- unused local; it is accumulated in place, so the result is discarded.)
function Concat:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   local offset = 1
   for _, module in ipairs(self.modules) do
      local width = module.output:size(self.dimension)
      module:accGradParameters(input,
                               gradOutput:narrow(self.dimension, offset, width),
                               scale)
      offset = offset + width
   end
end
|
||
-- Accumulate-and-apply parameter updates in each submodule with learning
-- rate `lr`, using its slice of gradOutput. (As in accGradParameters, the
-- original captured a nil-ish return value in an unused local.)
function Concat:accUpdateGradParameters(input, gradOutput, lr)
   local offset = 1
   for _, module in ipairs(self.modules) do
      local width = module.output:size(self.dimension)
      module:accUpdateGradParameters(input,
                                     gradOutput:narrow(self.dimension, offset, width),
                                     lr)
      offset = offset + width
   end
end
|
||
-- Zero the accumulated parameter gradients of every submodule.
function Concat:zeroGradParameters()
   for i = 1, #self.modules do
      self.modules[i]:zeroGradParameters()
   end
end
|
||
-- Apply a parameter update with the given learning rate to every submodule.
function Concat:updateParameters(learningRate)
   for i = 1, #self.modules do
      self.modules[i]:updateParameters(learningRate)
   end
end
|
||
-- Share the listed parameter fields with the corresponding submodules of
-- another Concat-like container `mlp` (assumes matching module lists).
function Concat:share(mlp, ...)
   for i, module in ipairs(self.modules) do
      module:share(mlp.modules[i], ...)
   end
end
|
||
-- Collect the parameters and parameter gradients of every submodule into
-- two flat lists (nested tables returned by submodules are flattened).
function Concat:parameters()
   -- Recursively append `src` (a tensor or a nested table) onto `dst`.
   local function flatten(dst, src)
      if type(src) ~= 'table' then
         dst[#dst + 1] = src
      else
         for i = 1, #src do
            flatten(dst, src[i])
         end
      end
   end

   local weights, gradWeights = {}, {}
   for i = 1, #self.modules do
      local mw, mgw = self.modules[i]:parameters()
      if mw then
         flatten(weights, mw)
         flatten(gradWeights, mgw)
      end
   end
   return weights, gradWeights
end
Oops, something went wrong.