-- attention.lua
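-- Builds the attention-gate network for the 'combined' task: it maps a token id,
-- the current recurrent state, and the token's surprisal to a scalar attention
-- weight in (0, 1). Individual inputs can be ablated via params.ablation.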
attention = {}
require('nn.BernoulliLayer')
attention.ABLATE_INPUT = false
attention.ABLATE_STATE = false
attention.ABLATE_SURPRISAL = false
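-- params.ablation is a string of flag characters: 'i' ablates the input embedding,
-- 'r' the recurrent state, and 's' the surprisal signal (each is zeroed out below).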
if string.match(params.ablation, 'i') then
  attention.ABLATE_INPUT = true
end
if string.match(params.ablation, 'r') then
  attention.ABLATE_STATE = true
end
if string.match(params.ablation, 's') then
  attention.ABLATE_SURPRISAL = true
end
print("ABLATION INP STATE SURP")
print(attention.ABLATE_INPUT)
print(attention.ABLATE_STATE)
print(attention.ABLATE_SURPRISAL)
function attention.createAttentionNetwork()
  assert(params.TASK == 'combined')
  if USE_BASELINE_NETWORK then
    return attention.createAttentionNetworkEmbeddingsSurprisalWithBaseline()
  end

  -- Graph inputs: token ids (x), recurrent state (y), and per-token surprisal.
  local x = nn.Identity()()
  local xemb = nn.BlockGradientLayer(params.batch_size, params.embeddings_dimensionality)(
    nn.LookupTable(params.vocab_size, params.embeddings_dimensionality)(x))
  local y = nn.Identity()()
  local surprisal = nn.Identity()()

  -- Ablation of input: zero out the word embedding
  if attention.ABLATE_INPUT then
    xemb = nn.MulConstant(0)(xemb)
  end

  -- Project each signal into the hidden space
  local x2h = nn.Linear(params.embeddings_dimensionality, params.rnn_size)(xemb)
  local y2h = nn.Linear(params.rnn_size, params.rnn_size)(y)

  -- Ablation of state: zero out the recurrent-state projection
  if attention.ABLATE_STATE then
    y2h = nn.MulConstant(0)(y2h)
  end

  local z2h = nn.Linear(1, params.rnn_size)(surprisal)

  -- Ablation of surprisal: zero out the surprisal projection
  if attention.ABLATE_SURPRISAL then
    z2h = nn.MulConstant(0)(z2h)
  end

  -- Combine the three projections and squash to a scalar attention weight per batch element
  local hidden = nn.Sigmoid()(nn.CAddTable()({x2h, y2h, z2h}))
  local att = nn.Sigmoid()(nn.Linear(params.rnn_size, 1)(hidden))

  local module = nn.gModule({x, y, surprisal}, {att})
  module:getParameters():uniform(-params.init_weight, params.init_weight)
  return transfer_data(module)
end
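
-- Minimal usage sketch (hypothetical tensors; params, transfer_data, and
-- nn.BlockGradientLayer are defined elsewhere in this repository):
--   local net = attention.createAttentionNetwork()
--   -- x: batch_size vector of token ids, y: batch_size x rnn_size state,
--   -- surprisal: batch_size x 1 surprisal values
--   local att = net:forward({x, y, surprisal})  -- batch_size x 1 weights in (0, 1)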