From 2ac59f930bd3333bbb1c47cf5fe77c3970e24f25 Mon Sep 17 00:00:00 2001
From: Nicholas Leonard
Date: Thu, 14 May 2015 17:23:30 -0400
Subject: [PATCH] Replicate batchMode

---
 Replicate.lua | 21 +++++++++++++--------
 test.lua      | 19 +++++++++++++++++++
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/Replicate.lua b/Replicate.lua
index 8e311039b..f66f2d5dd 100644
--- a/Replicate.lua
+++ b/Replicate.lua
@@ -1,9 +1,10 @@
 local Replicate, parent = torch.class('nn.Replicate','nn.Module')
 
-function Replicate:__init(nf, dim)
+function Replicate:__init(nf, dim, ndim)
    parent.__init(self)
    self.nfeatures = nf
    self.dim = dim or 1
+   self.ndim = ndim
    assert(self.dim > 0, "Can only replicate across positive integer dimensions.")
 end
 
@@ -13,20 +14,22 @@ function Replicate:updateOutput(input)
       self.dim <= input:dim()+1,
       "Not enough input dimensions to replicate along dimension " ..
       tostring(self.dim) .. ".")
+   local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0
+   local rdim = self.dim + batchOffset
    local sz = torch.LongStorage(input:dim()+1)
-   sz[self.dim] = self.nfeatures
+   sz[rdim] = self.nfeatures
    for i = 1,input:dim() do
       local offset = 0
-      if i >= self.dim then
+      if i >= rdim then
          offset = 1
       end
       sz[i+offset] = input:size(i)
    end
    local st = torch.LongStorage(input:dim()+1)
-   st[self.dim] = 0
+   st[rdim] = 0
    for i = 1,input:dim() do
       local offset = 0
-      if i >= self.dim then
+      if i >= rdim then
          offset = 1
       end
       st[i+offset] = input:stride(i)
@@ -37,16 +40,18 @@ end
 
 function Replicate:updateGradInput(input, gradOutput)
    self.gradInput:resizeAs(input):zero()
+   local batchOffset = self.ndim and input:dim() > self.ndim and 1 or 0
+   local rdim = self.dim + batchOffset
    local sz = torch.LongStorage(input:dim()+1)
-   sz[self.dim] = 1
+   sz[rdim] = 1
    for i = 1,input:dim() do
       local offset = 0
-      if i >= self.dim then
+      if i >= rdim then
          offset = 1
       end
       sz[i+offset] = input:size(i)
    end
    local gradInput = self.gradInput:view(sz)
-   gradInput:sum(gradOutput, self.dim)
+   gradInput:sum(gradOutput, rdim)
    return self.gradInput
 end
diff --git a/test.lua b/test.lua
index d82e470fe..5f408995b 100644
--- a/test.lua
+++ b/test.lua
@@ -3428,6 +3428,25 @@ function nntest.Replicate()
 
    mytester:assertTensorEq(vOutput1, expected1, precision, 'Wrong tiling of data when replicating vector.')
    mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating vector.')
+
+   -- batch mode
+   local vector = torch.rand(4,3)
+
+   local r1 = nn.Replicate(2, 1, 1)
+   local r2 = nn.Replicate(2, 2, 1)
+
+   local vOutput1 = r1:forward(vector):clone()
+   local vOutput2 = r2:forward(vector):clone()
+
+   local expected1 = torch.zeros(4, 2, 3)
+   local expected2 = torch.zeros(4, 3, 2)
+   expected1:select(2, 1):copy(vector)
+   expected1:select(2, 2):copy(vector)
+   expected2:select(3, 1):copy(vector)
+   expected2:select(3, 2):copy(vector)
+
+   mytester:assertTensorEq(vOutput1, expected1, precision, 'Wrong tiling of data when replicating batch vector.')
+   mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating batch vector.')
 end
 
 function nntest.BatchNormalization()
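
A minimal usage sketch of the new optional ndim argument, inferred from the test case added above (the tensor sizes are illustrative, not from any documentation): when ndim is given and the input has more dimensions than ndim, the first dimension is treated as a batch dimension and replication happens per sample.

require 'nn'

-- Existing behaviour, no ndim: replicate a 3-vector twice along dim 1.
local v = torch.rand(3)
print(nn.Replicate(2, 1):forward(v):size())        -- shape 2x3

-- New batch mode, ndim = 1: a single sample is declared 1-dimensional,
-- so a 4x3 input is read as a batch of four 3-vectors and the leading
-- batch dimension is left untouched.
local batch = torch.rand(4, 3)
print(nn.Replicate(2, 1, 1):forward(batch):size()) -- shape 4x2x3, matching expected1 in the test
print(nn.Replicate(2, 2, 1):forward(batch):size()) -- shape 4x3x2, matching expected2 in the test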