diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua index 390ace03c..fcce69d1f 100644 --- a/SpatialConvolutionMap.lua +++ b/SpatialConvolutionMap.lua @@ -29,14 +29,14 @@ function nn.tables.random(nin, nout, nto) local tbl = torch.Tensor(nker, 2) local fi = torch.randperm(nin) local frcntr = 1 - local nfi = math.floor(nin/nto) -- number of distinct nto chunks + local nfi = math.floor(nin/nto) -- number of distinct nto chunks local totbl = tbl:select(2,2) local frtbl = tbl:select(2,1) local fitbl = fi:narrow(1, 1, (nfi * nto)) -- part of fi that covers distinct chunks local ufrtbl= frtbl:unfold(1, nto, nto) local utotbl= totbl:unfold(1, nto, nto) local ufitbl= fitbl:unfold(1, nto, nto) - + -- start filling frtbl for i=1,nout do -- fro each unit in target map ufrtbl:select(1,i):copy(ufitbl:select(1,frcntr)) @@ -52,37 +52,6 @@ function nn.tables.random(nin, nout, nto) return tbl end -local function constructTableRev(conMatrix) - local conMatrixL = conMatrix:type('torch.LongTensor') - -- Construct reverse lookup connection table - local thickness = conMatrixL:select(2,2):max() - -- approximate fanin check - if (#conMatrixL)[1] % thickness == 0 then - -- do a proper fanin check and set revTable - local fanin = (#conMatrixL)[1] / thickness - local revTable = torch.Tensor(thickness, fanin, 2) - for ii=1,thickness do - local tempf = fanin - for jj=1,(#conMatrixL)[1] do - if conMatrixL[jj][2] == ii then - if tempf <= 0 then break end - revTable[ii][tempf][1] = conMatrixL[jj][1] - revTable[ii][tempf][2] = jj - tempf = tempf - 1 - end - end - if tempf ~= 0 then - fanin = -1 - break - end - end - if fanin ~= -1 then - return revTable - end - end - return {} -end - function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) parent.__init(self) @@ -94,14 +63,13 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) self.dW = dW self.dH = dH self.connTable = conMatrix - self.connTableRev = constructTableRev(conMatrix) self.nInputPlane = self.connTable:select(2,1):max() self.nOutputPlane = self.connTable:select(2,2):max() self.weight = torch.Tensor(self.connTable:size(1), kH, kW) self.bias = torch.Tensor(self.nOutputPlane) self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW) self.gradBias = torch.Tensor(self.nOutputPlane) - + self:reset() end