MaskZero.lua
------------------------------------------------------------------------
--[[ MaskZero ]]--
-- Zeroes the elements of the state tensors
-- (output/gradOutput/input/gradInput) of the encapsulated module
-- for commensurate elements that are 1 in self.zeroMask.
-- By default only output/gradOutput are zeroMasked.
-- self.zeroMask is set with setZeroMask(zeroMask).
-- Only works in batch-mode.
-- Note that zero-masking of input/gradInput is done in-place.
------------------------------------------------------------------------
local MaskZero, parent = torch.class("nn.MaskZero", "nn.Decorator")

function MaskZero:__init(module, v1, maskinput, maskoutput)
   parent.__init(self, module)
   assert(torch.isTypeOf(module, 'nn.Module'))
   self.maskinput = maskinput -- defaults to false
   self.maskoutput = maskoutput == nil and true or maskoutput -- defaults to true
   self.v2 = not v1
end

function MaskZero:updateOutput(input)
   if self.v2 then
      assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")
   else -- backwards compat
      self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask)
   end
   if self.maskinput and self.zeroMask then
      nn.utils.recursiveZeroMask(input, self.zeroMask)
   end
   -- forward through decorated module
   local output = self.modules[1]:updateOutput(input)
   if self.maskoutput and self.zeroMask then
      self.output = nn.utils.recursiveCopy(self.output, output)
      nn.utils.recursiveZeroMask(self.output, self.zeroMask)
   else
      self.output = output
   end
   return self.output
end

function MaskZero:updateGradInput(input, gradOutput)
   assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")
   if self.maskoutput and self.zeroMask then
      self.gradOutput = nn.utils.recursiveCopy(self.gradOutput, gradOutput)
      nn.utils.recursiveZeroMask(self.gradOutput, self.zeroMask)
      gradOutput = self.gradOutput
   end
   self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
   if self.maskinput and self.zeroMask then
      nn.utils.recursiveZeroMask(self.gradInput, self.zeroMask)
   end
   return self.gradInput
end

function MaskZero:clearState()
   self.output = nil
   self.gradInput = nil
   self.zeroMask = nil
   return self
end

function MaskZero:type(type, ...)
   self:clearState()
   return parent.type(self, type, ...)
end

function MaskZero:setZeroMask(zeroMask)
   if zeroMask == false then
      self.zeroMask = false
   else
      assert(torch.isByteTensor(zeroMask))
      assert(zeroMask:isContiguous())
      self.zeroMask = zeroMask
   end
end
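
------------------------------------------------------------------------
-- A minimal usage sketch, assuming the 'rnn' package is loaded and
-- illustrative sizes (batchsize = 4, insize = 3, outsize = 2):
--
--   local masked = nn.MaskZero(nn.Linear(3, 2))
--   local input = torch.randn(4, 3)
--   -- a 1 in zeroMask marks the corresponding sample as padding
--   local zeroMask = torch.ByteTensor(4):zero()
--   zeroMask[2] = 1
--   masked:setZeroMask(zeroMask)
--   local output = masked:forward(input) -- row 2 of output is all zeros
--
-- Calling masked:setZeroMask(false) disables masking for subsequent
-- forward/backward calls.
------------------------------------------------------------------------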