forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathZeroGrad.lua
28 lines (25 loc) · 856 Bytes
/
ZeroGrad.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
-- nn.ZeroGrad: an nn.Module whose forward pass hands its input through as the
-- output and whose backward pass produces an all-zero gradInput.
local ZeroGrad, parent
ZeroGrad, parent = torch.class("nn.ZeroGrad", "nn.Module")
--- Zero-fill t1 so that it mirrors the (possibly nested) structure and sizes of t2.
-- Existing tensors in t1 are reused as buffers when their type matches t2's;
-- otherwise a new tensor of t2's type is allocated with t2.new().
-- Keys present in t1 but absent from t2 are removed, so buffers left over from
-- an earlier, larger t2 do not leak into the result.
-- @param t1 buffer to reuse: nil, a tensor, or a (nested) table of tensors
-- @param t2 template: a tensor or a (nested) table of tensors
-- @return t1 (possibly newly created), zero-filled and shaped like t2
-- @return t2, unchanged
-- @raise error if t2 (or any nested value of t2) is neither a table nor a tensor
local function recursiveZero(t1, t2)
   if torch.type(t2) == 'table' then
      -- Reuse an existing table; otherwise wrap whatever t1 was, so that a
      -- lone tensor buffer can still be recycled as t1[1].
      t1 = (torch.type(t1) == 'table') and t1 or {t1}
      for key, _ in pairs(t2) do
         t1[key], t2[key] = recursiveZero(t1[key], t2[key])
      end
      -- Drop entries with no counterpart in t2 (clearing existing keys while
      -- traversing with pairs is allowed).
      for key, _ in pairs(t1) do
         if t2[key] == nil then
            t1[key] = nil
         end
      end
   elseif torch.isTensor(t2) then
      -- Only recycle t1 when it is a tensor of exactly t2's type: a table left
      -- over from an earlier table-valued t2 has no resizeAs, and mixing
      -- tensor types in resizeAs is avoided.
      if torch.type(t1) ~= torch.type(t2) then
         t1 = t2.new()
      end
      t1:resizeAs(t2):zero()
   else
      error("expecting nested tensors or tables. Got "..
            torch.type(t1).." and "..torch.type(t2).." instead")
   end
   return t1, t2
end
--- Forward pass: identity.
-- A tensor input is shared without copying: self.output is made to point at
-- the input's storage via set(). If self.output is not already a tensor of the
-- input's type, a fresh one is allocated first, so the caller's tensor is never
-- adopted as self.output and then mutated by a later set().
-- Any other input (e.g. a nested table of tensors, which updateGradInput also
-- handles) is passed through by reference.
-- @param input a tensor or a (nested) table of tensors
-- @return self.output
function ZeroGrad:updateOutput(input)
   if torch.isTensor(input) then
      if torch.type(self.output) ~= torch.type(input) then
         self.output = input.new()
      end
      self.output:set(input)
   else
      self.output = input
   end
   return self.output
end
-- The gradient is simply zeroed: gradInput gets the same (nested) structure
-- and sizes as gradOutput, with every element set to 0.
-- Useful when you don't want to backpropagate through certain paths.
-- @param input ignored
-- @param gradOutput a tensor or a (nested) table of tensors
-- @return self.gradInput
function ZeroGrad:updateGradInput(input, gradOutput)
   self.gradInput = recursiveZero(self.gradInput, gradOutput)
   -- Per the nn.Module contract, updateGradInput returns self.gradInput;
   -- Module:backward stores this return value back into self.gradInput, so
   -- returning nothing would reset the gradient to nil.
   return self.gradInput
end