-
Notifications
You must be signed in to change notification settings - Fork 1
/
glove_model.py
37 lines (27 loc) · 1.1 KB
/
glove_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
class GloveModel(nn.Module):
    """GloVe-style bilinear scoring model over a fixed vocabulary.

    Holds two embedding tables (``wi`` for the "i" side, ``wj`` for the
    "j" side) plus one scalar bias table per side (``bi``, ``bj``). For each
    index pair ``(i, j)`` the forward pass returns
    ``wi[i] . wj[j] + bi[i] + bj[j]``. No loss is defined here; presumably
    the score is regressed onto log co-occurrence counts by the training
    code -- TODO confirm against the caller.

    Args:
        vocab_size: Number of rows in every embedding and bias table.
        embedding_dim: Width of the ``wi`` and ``wj`` vectors.
    """

    def __init__(self, vocab_size: int, embedding_dim: int) -> None:
        super().__init__()
        # Each table is created and then immediately re-initialised with
        # Xavier-normal values. xavier_normal_ fills the Parameter in place,
        # so there is no need to assign the result back to ``.weight``.
        # The create/init interleaving is kept so that seeded runs consume
        # random numbers in the same order as before.
        self.wi = nn.Embedding(vocab_size, embedding_dim)
        xavier_normal_(self.wi.weight)
        self.wj = nn.Embedding(vocab_size, embedding_dim)
        xavier_normal_(self.wj.weight)
        # Biases are stored as (vocab_size, 1) embeddings: one scalar per token.
        self.bi = nn.Embedding(vocab_size, 1)
        xavier_normal_(self.bi.weight)
        self.bj = nn.Embedding(vocab_size, 1)
        xavier_normal_(self.bj.weight)

    def forward(self, i_indices: torch.Tensor, j_indices: torch.Tensor) -> torch.Tensor:
        """Score each (i, j) index pair.

        Args:
            i_indices: Integer tensor of row indices into ``wi`` / ``bi``.
            j_indices: Integer tensor of row indices into ``wj`` / ``bj``;
                must have the same (or a broadcast-compatible) shape as
                ``i_indices``.

        Returns:
            Float tensor with the broadcast shape of the two index tensors,
            holding ``wi[i] . wj[j] + bi[i] + bj[j]`` per element.
        """
        w_i = self.wi(i_indices)
        w_j = self.wj(j_indices)
        # squeeze(-1) drops only the trailing bias width of 1; a bare
        # squeeze() would also collapse a batch dimension of size 1.
        b_i = self.bi(i_indices).squeeze(-1)
        b_j = self.bj(j_indices).squeeze(-1)
        # The embedding axis is always last, so reducing over dim=-1 works for
        # index tensors of any rank (identical to dim=1 for 1-D batches).
        return torch.sum(w_i * w_j, dim=-1) + b_i + b_j
if __name__ == "__main__":
    # Smoke test: build a small randomly initialised model and score two
    # (i, j) index pairs. Both index tensors have the same length so the
    # pairing is explicit rather than relying on broadcasting.
    glove = GloveModel(100, 10)
    with torch.no_grad():  # inference only; no autograd graph needed
        x = glove(torch.LongTensor([1, 2]), torch.LongTensor([1, 3]))
    print(x)