Skip to content

Commit

Permalink
Merge pull request torch#269 from georgevdd/master
Browse files Browse the repository at this point in the history
Remove unused and expensive init logic from nn.SpatialConvolutionMap.
  • Loading branch information
koraykv committed May 20, 2015
2 parents 88ea556 + a28545a commit b3f7bcc
Showing 1 changed file with 3 additions and 35 deletions.
38 changes: 3 additions & 35 deletions SpatialConvolutionMap.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ function nn.tables.random(nin, nout, nto)
local tbl = torch.Tensor(nker, 2)
local fi = torch.randperm(nin)
local frcntr = 1
local nfi = math.floor(nin/nto) -- number of distinct nto chunks
local nfi = math.floor(nin/nto) -- number of distinct nto chunks
local totbl = tbl:select(2,2)
local frtbl = tbl:select(2,1)
local fitbl = fi:narrow(1, 1, (nfi * nto)) -- part of fi that covers distinct chunks
local ufrtbl= frtbl:unfold(1, nto, nto)
local utotbl= totbl:unfold(1, nto, nto)
local ufitbl= fitbl:unfold(1, nto, nto)

-- start filling frtbl
for i=1,nout do -- fro each unit in target map
ufrtbl:select(1,i):copy(ufitbl:select(1,frcntr))
Expand All @@ -52,37 +52,6 @@ function nn.tables.random(nin, nout, nto)
return tbl
end

local function constructTableRev(conMatrix)
local conMatrixL = conMatrix:type('torch.LongTensor')
-- Construct reverse lookup connection table
local thickness = conMatrixL:select(2,2):max()
-- approximate fanin check
if (#conMatrixL)[1] % thickness == 0 then
-- do a proper fanin check and set revTable
local fanin = (#conMatrixL)[1] / thickness
local revTable = torch.Tensor(thickness, fanin, 2)
for ii=1,thickness do
local tempf = fanin
for jj=1,(#conMatrixL)[1] do
if conMatrixL[jj][2] == ii then
if tempf <= 0 then break end
revTable[ii][tempf][1] = conMatrixL[jj][1]
revTable[ii][tempf][2] = jj
tempf = tempf - 1
end
end
if tempf ~= 0 then
fanin = -1
break
end
end
if fanin ~= -1 then
return revTable
end
end
return {}
end

function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
parent.__init(self)

Expand All @@ -94,14 +63,13 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
self.dW = dW
self.dH = dH
self.connTable = conMatrix
self.connTableRev = constructTableRev(conMatrix)
self.nInputPlane = self.connTable:select(2,1):max()
self.nOutputPlane = self.connTable:select(2,2):max()
self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
self.bias = torch.Tensor(self.nOutputPlane)
self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)
self.gradBias = torch.Tensor(self.nOutputPlane)

self:reset()
end

Expand Down

0 comments on commit b3f7bcc

Please sign in to comment.