# model_utils.py (forked from gimpong/MM23-MISSRec)

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss


class Transformer(SequentialRecommender):
def __init__(self, config, dataset):
super(Transformer, self).__init__(config, dataset)

        # load parameters info
self.n_layers = config['n_layers']
self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']  # same as embedding_size
        self.inner_size = config['inner_size']  # dimensionality of the feed-forward layer
self.hidden_dropout_prob = config['hidden_dropout_prob']
self.attn_dropout_prob = config['attn_dropout_prob']
self.hidden_act = config['hidden_act']
self.layer_norm_eps = config['layer_norm_eps']
self.initializer_range = config['initializer_range']
self.loss_type = config['loss_type']

        # define layers and loss
self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)
self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
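        # NOTE: nn.Transformer has no built-in positional encoding, so learned
        # position embeddings are defined here; the intended use is to add them
        # to the item embeddings before the sequence enters the transformer.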
assert self.n_layers % 2 == 0, "'n_layers' must be an even number"
self.trm_model = nn.Transformer(
d_model=self.hidden_size,
nhead=self.n_heads,
num_encoder_layers=self.n_layers // 2,
num_decoder_layers=self.n_layers // 2,
dim_feedforward=self.inner_size,
dropout=self.hidden_dropout_prob,
activation=self.hidden_act,
layer_norm_eps=self.layer_norm_eps,
batch_first=True,
norm_first=False,
)
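        # With batch_first=True, the transformer expects inputs of shape
        # (batch, seq_len, hidden). If a causal mask is needed for the decoder,
        # it can be built with nn.Transformer.generate_square_subsequent_mask(seq_len).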
self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.dropout = nn.Dropout(self.hidden_dropout_prob)

        if self.loss_type == 'BPR':
self.loss_fct = BPRLoss()
elif self.loss_type == 'CE':
self.loss_fct = nn.CrossEntropyLoss()
else:
raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
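

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of instantiating the model through RecBole's standard
# config/data pipeline. It assumes a RecBole atomic dataset ('ml-100k' here)
# is available locally; the hyperparameter values below are placeholders,
# not the authors' settings.
if __name__ == '__main__':
    from recbole.config import Config
    from recbole.data import create_dataset, data_preparation

    config = Config(
        model=Transformer,
        dataset='ml-100k',
        config_dict={
            'n_layers': 2,  # must be even: split between encoder and decoder
            'n_heads': 2,
            'hidden_size': 64,
            'inner_size': 256,
            'hidden_dropout_prob': 0.5,
            'attn_dropout_prob': 0.5,
            'hidden_act': 'gelu',  # nn.Transformer accepts 'relu' or 'gelu'
            'layer_norm_eps': 1e-12,
            'initializer_range': 0.02,
            'loss_type': 'CE',
        },
    )
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    model = Transformer(config, train_data.dataset).to(config['device'])
    print(model)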