# @Time : 2022/3/7
# @Author : Yupeng Hou
# @Email : [email protected]
r"""
SRGNN
################################################
Reference:
Shu Wu et al. "Session-based Recommendation with Graph Neural Networks." in AAAI 2019.
Reference code:
https://github.com/CRIPAC-DIG/SR-GNN
"""

import numpy as np
import torch
from torch import nn

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss

from recbole_gnn.model.layers import SRGNNCell


class SRGNN(SequentialRecommender):
    r"""SRGNN regards the session history as a directed graph.

    In addition to the connections between each item and its adjacent items,
    it also models the connections with the other items interacted with in
    the same session. For example, for the session sequence
    (item1, item2, item3, item2, item4), the connection matrix :math:`A`
    consists of the following outgoing and incoming parts.

    Outgoing edges:

    === ===== ===== ===== =====
     \    1     2     3     4
    === ===== ===== ===== =====
     1    0     1     0     0
     2    0     0    1/2   1/2
     3    0     1     0     0
     4    0     0     0     0
    === ===== ===== ===== =====

    Incoming edges:

    === ===== ===== ===== =====
     \    1     2     3     4
    === ===== ===== ===== =====
     1    0     0     0     0
     2   1/2    0    1/2    0
     3    0     1     0     0
     4    0     1     0     0
    === ===== ===== ===== =====
    """

    def __init__(self, config, dataset):
        super(SRGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)

        # define layers and loss
        self.gnncell = SRGNNCell(self.embedding_size)
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size)
        self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size)

        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        # uniform initialization in [-1/sqrt(d), 1/sqrt(d)], following the
        # reference SR-GNN implementation
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
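
    # Shapes of the tensors consumed by forward() below, inferred from how
    # they are used (a descriptive note, not part of the reference code):
    #     x:            [num_nodes]     item ids of the batched session graph
    #     edge_index:   [2, num_edges]  directed session-graph edges (PyG style)
    #     alias_inputs: [B, L]          index into x for each sequence position
    #     item_seq_len: [B]             true (unpadded) session lengths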
    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        # positions whose alias index is 0 point at padding
        mask = alias_inputs.gt(0)
        hidden = self.item_embedding(x)
        # propagate node embeddings over the session graph for `step` rounds
        for _ in range(self.step):
            hidden = self.gnncell(hidden, edge_index)
        # map node embeddings back to sequence order
        seq_hidden = hidden[alias_inputs]
        # fetch the hidden state at the last timestep
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        # soft attention over the sequence, queried by the last hidden state
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)
        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        # combine the global preference with the last interest
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
        return seq_output
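
    # The readout above follows the session-embedding equations of the
    # SR-GNN paper (bias terms elided in this sketch):
    #     alpha_i = q^T sigmoid(W1 v_n + W2 v_i)
    #     s_g     = sum_i alpha_i v_i          (padding positions masked out)
    #     s_h     = W3 [s_g ; v_n]
    # where v_n is the hidden state of the last item in the session.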

    def calculate_loss(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type == 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
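

# A minimal usage sketch. Assumption: recbole_gnn exposes a RecBole-style
# quick-start entry point; the function name and the dataset below are
# illustrative, not verified against this repository:
#
#     from recbole_gnn.quick_start import run_recbole_gnn
#
#     run_recbole_gnn(model='SRGNN', dataset='diginetica')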