forked from SchattenGenie/mlhep2020_pid_sparse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sparse_model.py
22 lines (20 loc) · 857 Bytes
/
sparse_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter
from sparse_vd import LinearSVDO
class SparseNet(nn.Module):
    """Three-layer MLP classifier built from ``LinearSVDO`` layers.

    Layout: ``input_dim -> hidden_dim -> hidden_dim -> num_classes`` with a
    ReLU after each of the first two layers, and ``log_softmax`` over
    dimension 1 on the output.

    Args:
        input_dim: Number of input features fed to the first layer.
        device: Forwarded to every ``LinearSVDO`` layer and used to place
            the ``threshold`` parameter alongside them.
        threshold: Forwarded to every ``LinearSVDO`` layer and additionally
            stored on the model as the frozen parameter ``self.threshold``.
        hidden_dim: Width of both hidden layers (default 100, the value
            that used to be hard-coded).
        num_classes: Number of output classes (default 6, the value that
            used to be hard-coded).
    """

    def __init__(self, input_dim, device, threshold=1., hidden_dim=100, num_classes=6):
        super().__init__()
        self.fc1 = LinearSVDO(input_dim, hidden_dim, threshold=threshold, device=device)
        self.fc2 = LinearSVDO(hidden_dim, hidden_dim, threshold=threshold, device=device)
        self.fc3 = LinearSVDO(hidden_dim, num_classes, threshold=threshold, device=device)
        # The model must expose ``threshold`` as a Parameter with
        # requires_grad=False. Registering it as a Parameter puts it in
        # state_dict() and makes it follow .to(); freezing it keeps it out of
        # gradient computation. It is created on ``device`` so it starts out
        # on the same device as the layers above.
        self.threshold = Parameter(
            torch.as_tensor(threshold, device=device), requires_grad=False
        )

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``.

        Args:
            x: Input tensor; assumed to be ``(batch, input_dim)`` since the
                softmax is taken over dim=1 — TODO confirm with callers.

        Returns:
            Tensor of log-probabilities from ``log_softmax`` over dim=1
            (``(batch, num_classes)`` for a 2-D input).
        """
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return F.log_softmax(self.fc3(hidden), dim=1)