model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_Text(nn.Module):
    """CNN text classifier: parallel 2-D convolutions over an embedded
    token sequence, max-over-time pooling, dropout, then a linear layer."""

    def __init__(self, args):
        super(CNN_Text, self).__init__()
        D = args.get("embed_dim")      # embedding dimension
        C = args.get("labels")         # number of output classes
        Ci = 1                         # input channels for Conv2d
        Co = args.get("kernel_num")    # feature maps per kernel size
        Ks = args.get("kernel_sizes")  # convolution window heights
        # With a pretrained embedding file, inputs are expected to arrive
        # already embedded, so no nn.Embedding layer is built here.
        self.pre_embed = False
        if args.get("embed_file"):
            self.pre_embed = True
        else:
            V = args.get("vocab_size")
            self.embed = nn.Embedding(V, D)
            if args.get("static"):
                # Freeze embedding weights (no fine-tuning).
                self.embed.weight.requires_grad = False
        # One Conv2d per kernel size K, each spanning the full embedding width D.
        self.convs = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])
        self.dropout = nn.Dropout(args.get("dropout"))
        self.fc1 = nn.Linear(len(Ks) * Co, C)

    def forward(self, x):
        # x: (N, W) token indices, or (N, W, D) floats when pre-embedded.
        if not self.pre_embed:
            x = self.embed(x)  # (N, W, D)
x = x.unsqueeze(1) # (N, Ci, W, D)
x = [
F.relu(conv(x)).squeeze(3) for conv in self.convs
] # [(N, Co, W), ...]*len(Ks)
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks)
x = torch.cat(x, 1)
x = self.dropout(x) # (N, len(Ks)*Co)
logit = self.fc1(x) # (N, C)
return logit
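

# A minimal smoke test (not in the original file): the dict keys mirror what
# __init__ reads above, but the concrete hyperparameter values here are
# illustrative assumptions only.
if __name__ == "__main__":
    args = {
        "embed_dim": 128,
        "labels": 2,
        "kernel_num": 100,
        "kernel_sizes": [3, 4, 5],
        "dropout": 0.5,
        "vocab_size": 10000,
        "embed_file": None,  # no pretrained file, so an nn.Embedding is built
        "static": False,
    }
    model = CNN_Text(args)
    tokens = torch.randint(0, args["vocab_size"], (8, 50))  # batch of 8, length 50
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([8, 2])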