MixtureTable lazy initialized buffers
nicholas-leonard committed May 21, 2015
1 parent b3f7bcc commit 1e7eec6
Showing 2 changed files with 41 additions and 11 deletions.
35 changes: 24 additions & 11 deletions MixtureTable.lua
@@ -3,23 +3,21 @@ local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module')
function MixtureTable:__init(dim)
   parent.__init(self)
   self.dim = dim
-   self._gaterView = torch.Tensor()
-   self._expert = torch.Tensor()
-   self._expertView = torch.Tensor()
-   self._sum = torch.Tensor()
   self.size = torch.LongStorage()
   self.batchSize = 0
-   self.gradInput = {torch.Tensor(), {}}
-   self._gradInput = torch.Tensor()
   self.size2 = torch.LongStorage()
-   self._expertView2 = torch.Tensor()
-   self._expert2 = torch.Tensor()
   self.backwardSetup = false
+   self.gradInput = {}
end

function MixtureTable:updateOutput(input)
   local gaterInput, expertInputs = table.unpack(input)

+   -- buffers
+   self._gaterView = self._gaterView or input[1].new()
+   self._expert = self._expert or input[1].new()
+   self._expertView = self._expertView or input[1].new()
+
   self.dimG = 2
   local batchSize = gaterInput:size(1)
   if gaterInput:dim() < 2 then
@@ -43,9 +41,6 @@ function MixtureTable:updateOutput(input)
      end
      self.size[self.dim] = gaterInput:size(self.dimG)
      self.output:resizeAs(expertInput)
-      if torch.type(self.gradInput[2]) ~= 'table' then
-         self.gradInput[2] = {}
-      end
      self.backwardSetup = false
      self.batchSize = batchSize
   end
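
The buffer lines added to updateOutput above are the heart of the change: instead of being allocated eagerly in __init, each working tensor is now created on first use, typed like the incoming data via input[1].new(). A minimal sketch of the idiom; the _buffer name and the ref tensor are hypothetical, not part of this commit:

   -- lazy initialization: allocate the buffer on the first call only,
   -- with the same tensor type as the input
   self._buffer = self._buffer or input[1].new()
   self._buffer:resizeAs(ref)   -- later calls just reuse and resize it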
@@ -80,7 +75,14 @@ end

function MixtureTable:updateGradInput(input, gradOutput)
   local gaterInput, expertInputs = table.unpack(input)
+   nn.utils.recursiveResizeAs(self.gradInput, input)
   local gaterGradInput, expertGradInputs = table.unpack(self.gradInput)

+   -- buffers
+   self._sum = self._sum or input[1].new()
+   self._gradInput = self._gradInput or {input[1].new(), {}}
+   self._expertView2 = self._expertView2 or input[1].new()
+   self._expert2 = self._expert2 or input[1].new()
+
   if self.table then
      if not self.backwardSetup then
@@ -146,3 +148,14 @@ function MixtureTable:updateGradInput(input, gradOutput)

   return self.gradInput
end
+
+function MixtureTable:type(type)
+   self._gaterView = nil
+   self._expert = nil
+   self._expertView = nil
+   self._sum = nil
+   self._gradInput = nil
+   self._expert2 = nil
+   self._expertView2 = nil
+   return parent.type(self, type)
+end
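
The new type method drops every lazily created buffer before delegating to parent.type, so the next forward/backward pass rebuilds them in the converted tensor type instead of carrying stale tensors of the old type. A usage sketch under assumed shapes (two samples, three experts of size four; not taken from this commit):

   require 'nn'

   local mix = nn.MixtureTable()
   local gater = torch.rand(2, 3)     -- batch of 2, one weight per expert
   local experts = {torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)}
   local out = mix:forward{gater, experts}   -- buffers created as DoubleTensors
   mix:float()                               -- type() clears the buffers...
   -- ...so the next forward with float inputs recreates them as FloatTensors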
17 changes: 17 additions & 0 deletions utils.lua
@@ -14,4 +14,21 @@ function nn.utils.recursiveType(param, type_str)
   return param
end

+function nn.utils.recursiveResizeAs(t1,t2)
+   if torch.type(t2) == 'table' then
+      t1 = (torch.type(t1) == 'table') and t1 or {t1}
+      for key,_ in pairs(t2) do
+         t1[key], t2[key] = nn.utils.recursiveResizeAs(t1[key], t2[key])
+      end
+   elseif torch.isTensor(t2) then
+      t1 = torch.isTensor(t1) and t1 or t2.new()
+      t1:resizeAs(t2)
+   else
+      error("expecting nested tensors or tables. Got "..
+            torch.type(t1).." and "..torch.type(t2).." instead")
+   end
+   return t1, t2
+end
+
+table.unpack = table.unpack or unpack
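
recursiveResizeAs lets updateGradInput mirror an arbitrarily nested input structure in self.gradInput, creating tensors where slots are missing and resizing existing ones in place; the final line is a Lua 5.1/5.2 compatibility shim, since 5.1 only provides the global unpack. A small sketch of what the helper produces (the shapes here are arbitrary):

   local input = {torch.randn(2, 3), {torch.randn(2, 4), torch.randn(2, 5)}}
   local gradInput = {}
   -- after the call, gradInput holds one tensor per input slot,
   -- each resized to match the corresponding input tensor
   gradInput = nn.utils.recursiveResizeAs(gradInput, input)
   print(gradInput[2][2]:size())   -- 2x5, same as input[2][2]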
