-
Notifications
You must be signed in to change notification settings - Fork 14
/
model.py
executable file
·162 lines (143 loc) · 6.86 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import datetime
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Module, Parameter
import torch.nn.functional as F
class GNN(Module):
    """Gated graph network that refines node states over a session graph.

    Each propagation step multiplies the node states by the two halves of the
    adjacency tensor ``A`` (named "in" and "out" in the code). It concatenates
    the two resulting messages and feeds them through a GRU-style gated update.

    Args:
        hidden_size: dimensionality of every node state.
        step: number of propagation steps performed by ``forward``.
    """

    def __init__(self, hidden_size, step=1):
        super(GNN, self).__init__()
        self.step = step
        self.hidden_size = hidden_size
        self.input_size = 2 * hidden_size  # "in" message concatenated with "out" message
        self.gate_size = 3 * hidden_size  # reset, input and new-state gates stacked together
        # NOTE: these tensors are allocated uninitialised; SessionGraph.reset_parameters
        # overwrites them.  A standalone GNN must be initialised by its caller.
        self.w_ih = Parameter(torch.empty(self.gate_size, self.input_size))
        self.w_hh = Parameter(torch.empty(self.gate_size, self.hidden_size))
        self.b_ih = Parameter(torch.empty(self.gate_size))
        self.b_hh = Parameter(torch.empty(self.gate_size))
        self.b_iah = Parameter(torch.empty(self.hidden_size))
        self.b_oah = Parameter(torch.empty(self.hidden_size))
        self.linear_edge_in = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.linear_edge_out = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        # NOTE(review): linear_edge_f is never used in the forward pass; it is kept so
        # parameter lists, RNG consumption and saved state_dicts stay unchanged.
        self.linear_edge_f = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

    def GNNCell(self, A, hidden):
        """Apply one gated propagation step.

        Args:
            A: adjacency tensor whose last dimension holds two matrices of width
                ``A.shape[1]`` side by side (assumed shape (batch, n, 2n) — TODO
                confirm against the data loader).
            hidden: node states, matmul-compatible with each half of ``A``
                (presumably (batch, n, hidden_size)).

        Returns:
            The updated node states, same shape as ``hidden``.
        """
        n = A.shape[1]
        adj_in = A[:, :, :n]
        adj_out = A[:, :, n:2 * n]
        msg_in = torch.matmul(adj_in, self.linear_edge_in(hidden)) + self.b_iah
        msg_out = torch.matmul(adj_out, self.linear_edge_out(hidden)) + self.b_oah
        messages = torch.cat([msg_in, msg_out], 2)

        from_msgs = F.linear(messages, self.w_ih, self.b_ih)
        from_state = F.linear(hidden, self.w_hh, self.b_hh)
        m_reset, m_update, m_new = from_msgs.chunk(3, 2)
        s_reset, s_update, s_new = from_state.chunk(3, 2)

        reset = torch.sigmoid(m_reset + s_reset)
        update = torch.sigmoid(m_update + s_update)
        candidate = torch.tanh(m_new + reset * s_new)
        # GRU interpolation: candidate + z * (h - candidate) == (1 - z) * candidate + z * h
        return candidate + update * (hidden - candidate)

    def forward(self, A, hidden):
        """Run ``self.step`` propagation steps and return the final node states."""
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden
class SessionGraph(Module):
    """Session recommender: a GNN encoder plus session and target attention scoring.

    The module also owns its loss function, Adam optimizer and StepLR scheduler.
    The training loop uses these directly.
    """

    def __init__(self, opt, n_node):
        super().__init__()
        hidden = opt.hiddenSize
        self.hidden_size = hidden
        self.n_node = n_node
        self.batch_size = opt.batchSize
        self.nonhybrid = opt.nonhybrid

        # Keep this construction order: nn.Embedding / nn.Linear draw from the global
        # RNG when built, and reset_parameters walks parameters in registration order.
        self.embedding = nn.Embedding(n_node, hidden)
        self.gnn = GNN(hidden, step=opt.step)
        self.linear_one = nn.Linear(hidden, hidden, bias=True)
        self.linear_two = nn.Linear(hidden, hidden, bias=True)
        self.linear_three = nn.Linear(hidden, 1, bias=False)
        self.linear_transform = nn.Linear(2 * hidden, hidden, bias=True)
        self.linear_t = nn.Linear(hidden, hidden, bias=False)  # target-attention projection

        self.loss_function = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=opt.lr, weight_decay=opt.l2)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter, embeddings included, from U(-1/sqrt(d), 1/sqrt(d))."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def compute_scores(self, hidden, mask):
        """Score every candidate item for each session in the batch.

        Args:
            hidden: per-position session states, (batch, seq_len, hidden_size).
            mask: padding mask, (batch, seq_len), presumably 1 for real positions and
                0 for padding. The last real position is read at ``sum(mask) - 1``,
                so sessions are assumed right-padded.

        Returns:
            Scores of shape (batch, n_node - 1); column k belongs to embedding row k + 1.
        """
        batch = mask.shape[0]
        mask_f = mask.view(batch, -1, 1).float()
        # State at the last real position of every session.
        last = hidden[torch.arange(batch).long(), torch.sum(mask, 1) - 1]

        # Soft attention over the session, queried by the last state.
        # NOTE(review): softmax runs over padded positions too, so part of the weight
        # mass lands on padding before the mask is applied (weights are not renormalised).
        query = self.linear_one(last).unsqueeze(1)
        keys = self.linear_two(hidden)
        weights = F.softmax(self.linear_three(torch.sigmoid(query + keys)), 1)
        session = torch.sum(weights * hidden * mask_f, 1)
        if not self.nonhybrid:
            session = self.linear_transform(torch.cat([session, last], 1))

        # Target attention: every candidate embedding attends over the masked states.
        # NOTE(review): padded rows are zero and linear_t has no bias, so they get logit 0
        # and still receive softmax weight.
        candidates = self.embedding.weight[1:]
        masked = hidden * mask_f
        target_logits = candidates @ self.linear_t(masked).transpose(1, 2)
        target = F.softmax(target_logits, -1) @ masked

        return torch.sum((session.unsqueeze(1) + target) * candidates, -1)

    def forward(self, inputs, A):
        """Embed the item ids ``inputs`` and propagate them over the session graph ``A``."""
        return self.gnn(A, self.embedding(inputs))
def trans_to_cuda(variable):
    """Return ``variable`` moved to the default CUDA device, or unchanged if CUDA is absent."""
    return variable.cuda() if torch.cuda.is_available() else variable
def trans_to_cpu(variable):
    """Return ``variable`` on the CPU when CUDA is available; otherwise return it untouched."""
    if not torch.cuda.is_available():
        return variable
    return variable.cpu()
def forward(model, i, data):
    """Run ``model`` on batch ``i`` of ``data`` and score every candidate item.

    Args:
        model: the ``SessionGraph`` to evaluate.
        i: batch selector passed straight to ``data.get_slice``.
        data: dataset whose ``get_slice(i)`` returns
            ``(alias_inputs, A, items, mask, targets)``.

    Returns:
        ``(targets, scores)``: ``targets`` exactly as returned by ``get_slice``;
        ``scores`` from ``model.compute_scores``.
    """
    alias_inputs, A, items, mask, targets = data.get_slice(i)
    # Build integer tensors directly as int64.  The former torch.Tensor(x).long()
    # detour went through float32, which cannot represent integers above 2**24
    # exactly and would silently corrupt large item ids.
    alias_inputs = trans_to_cuda(torch.tensor(np.asarray(alias_inputs), dtype=torch.long))
    items = trans_to_cuda(torch.tensor(np.asarray(items), dtype=torch.long))
    A = trans_to_cuda(torch.tensor(np.asarray(A), dtype=torch.float))
    mask = trans_to_cuda(torch.tensor(np.asarray(mask), dtype=torch.long))
    hidden = model(items, A)
    # alias_inputs[b, t] indexes the row of hidden[b] for sequence position t.
    # One batched gather replaces the per-session Python loop + stack:
    # seq_hidden[b, t] = hidden[b, alias_inputs[b, t]].
    rows = torch.arange(alias_inputs.shape[0], device=alias_inputs.device).unsqueeze(1)
    seq_hidden = hidden[rows, alias_inputs]
    return targets, model.compute_scores(seq_hidden, mask)
def train_test(model, train_data, test_data, top_k=20):
    """Train ``model`` for one epoch, then evaluate its ranking quality.

    Args:
        model: a ``SessionGraph`` carrying its own optimizer, scheduler and loss.
        train_data: dataset used for the training pass (``generate_batch`` /
            ``get_slice`` via ``forward``).
        test_data: dataset used for evaluation, same interface.
        top_k: ranking cut-off for both metrics (default 20, the previous fixed value).

    Returns:
        ``(hit, mrr)``: Hit@top_k and MRR@top_k over all test sessions, in percent.
    """
    print('start training: ', datetime.datetime.now())
    model.train()
    total_loss = 0.0
    slices = train_data.generate_batch(model.batch_size)
    log_every = int(len(slices) / 5 + 1)  # print progress about five times per epoch
    for j, i in enumerate(slices):
        model.optimizer.zero_grad()
        targets, scores = forward(model, i, train_data)
        # Build targets directly as int64 (float32 would corrupt ids above 2**24).
        targets = trans_to_cuda(torch.tensor(np.asarray(targets), dtype=torch.long))
        # Score column k corresponds to item id k + 1, hence the shift.
        loss = model.loss_function(scores, targets - 1)
        loss.backward()
        model.optimizer.step()
        total_loss += loss.item()
        if j % log_every == 0:
            print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item()))
    print('\tLoss:\t%.3f' % total_loss)
    # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer steps.
    # Stepping it before training skips the first value of the schedule, making every
    # decay happen one epoch early.
    model.scheduler.step()

    print('start predicting: ', datetime.datetime.now())
    model.eval()
    hit, mrr = [], []
    slices = test_data.generate_batch(model.batch_size)
    with torch.no_grad():  # evaluation only: skip building autograd graphs
        for i in slices:
            targets, scores = forward(model, i, test_data)
            sub_scores = trans_to_cpu(scores.topk(top_k)[1]).detach().numpy()
            for score, target in zip(sub_scores, targets):
                # 0-based rank(s) at which the true item (shifted to column space) appears.
                ranks = np.where(score == target - 1)[0]
                hit.append(ranks.size > 0)
                mrr.append(1 / (ranks[0] + 1) if ranks.size else 0)
    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    return hit, mrr