Skip to content

Commit

Permalink
Replicate batchMode
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed May 14, 2015
1 parent 28bb486 commit 2ac59f9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
21 changes: 13 additions & 8 deletions Replicate.lua
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
local Replicate, parent = torch.class('nn.Replicate','nn.Module')

function Replicate:__init(nf, dim)
function Replicate:__init(nf, dim, ndim)
parent.__init(self)
self.nfeatures = nf
self.dim = dim or 1
self.ndim = ndim
assert(self.dim > 0, "Can only replicate across positive integer dimensions.")
end

Expand All @@ -13,20 +14,22 @@ function Replicate:updateOutput(input)
self.dim <= input:dim()+1,
"Not enough input dimensions to replicate along dimension " ..
tostring(self.dim) .. ".")
local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0
local rdim = self.dim + batchOffset
local sz = torch.LongStorage(input:dim()+1)
sz[self.dim] = self.nfeatures
sz[rdim] = self.nfeatures
for i = 1,input:dim() do
local offset = 0
if i >= self.dim then
if i >= rdim then
offset = 1
end
sz[i+offset] = input:size(i)
end
local st = torch.LongStorage(input:dim()+1)
st[self.dim] = 0
st[rdim] = 0
for i = 1,input:dim() do
local offset = 0
if i >= self.dim then
if i >= rdim then
offset = 1
end
st[i+offset] = input:stride(i)
Expand All @@ -37,16 +40,18 @@ end

function Replicate:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(input):zero()
local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0
local rdim = self.dim + batchOffset
local sz = torch.LongStorage(input:dim()+1)
sz[self.dim] = 1
sz[rdim] = 1
for i = 1,input:dim() do
local offset = 0
if i >= self.dim then
if i >= rdim then
offset = 1
end
sz[i+offset] = input:size(i)
end
local gradInput = self.gradInput:view(sz)
gradInput:sum(gradOutput, self.dim)
gradInput:sum(gradOutput, rdim)
return self.gradInput
end
19 changes: 19 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3428,6 +3428,25 @@ function nntest.Replicate()

mytester:assertTensorEq(vOutput1, expected1, precision, 'Wrong tiling of data when replicating vector.')
mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating vector.')

-- batch mode
local vector = torch.rand(4,3)

local r1 = nn.Replicate(2, 1, 1)
local r2 = nn.Replicate(2, 2, 1)

local vOutput1 = r1:forward(vector):clone()
local vOutput2 = r2:forward(vector):clone()

local expected1 = torch.zeros(4, 2, 3)
local expected2 = torch.zeros(4, 3, 2)
expected1:select(2, 1):copy(vector)
expected1:select(2, 2):copy(vector)
expected2:select(3, 1):copy(vector)
expected2:select(3, 2):copy(vector)

mytester:assertTensorEq(vOutput1, expected1, precision, 'Wrong tiling of data when replicating batch vector.')
mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating batch vector.')
end

function nntest.BatchNormalization()
Expand Down

0 comments on commit 2ac59f9

Please sign in to comment.