-
Notifications
You must be signed in to change notification settings - Fork 37
/
models.py
93 lines (73 loc) · 2.31 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from torch import nn
from torch.autograd import Variable
class ReptileModel(nn.Module):
    """Base module adding the gradient helper used for Reptile-style updates."""

    def __init__(self):
        super(ReptileModel, self).__init__()

    def point_grad_to(self, target):
        """
        Set the .grad attribute of each parameter to the difference
        between self and target (``self - target``).

        Any gradient already stored on a parameter is overwritten, not
        accumulated. Parameters are paired positionally with ``zip``, so
        if the two models expose different numbers of parameters the
        surplus ones are left untouched.
        """
        for p, target_p in zip(self.parameters(), target.parameters()):
            if p.grad is None:
                # zeros_like inherits the parameter's device and dtype, so
                # the buffer is valid for non-float32 parameters and for
                # parameters living on a non-default CUDA device.
                p.grad = torch.zeros_like(p.data)
            # Overwrite in place; equivalent to zeroing and then adding.
            p.grad.data.copy_(p.data - target_p.data)

    def is_cuda(self):
        """
        Return True if the first parameter is stored on a CUDA device.

        A module without any parameters reports False.
        """
        first = next(self.parameters(), None)
        return first is not None and first.is_cuda
class OmniglotModel(ReptileModel):
    """
    A model for Omniglot classification.

    Four stride-2 convolution blocks (conv, batch norm, ReLU) followed by
    a linear layer and a log-softmax over ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        """Build the network; ``num_classes`` is the size of the output layer."""
        ReptileModel.__init__(self)

        # Kept so that clone() can rebuild an identically shaped model.
        self.num_classes = num_classes

        self.conv = nn.Sequential(
            # 28 x 28 - 1
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 14 x 14 - 64
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 7 x 7 - 64
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 4 x 4 - 64
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 2 x 2 - 64
        )

        self.classifier = nn.Sequential(
            # 2 x 2 x 64 = 256
            nn.Linear(256, num_classes),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """
        Return per-class log-probabilities for a batch of inputs.

        The input is reshaped to (-1, 1, 28, 28) before the convolutions,
        so any tensor whose element count is a multiple of 28 * 28 is
        accepted.
        """
        out = x.view(-1, 1, 28, 28)
        out = self.conv(out)
        # Flatten everything except the batch dimension for the classifier.
        out = out.view(len(out), -1)
        return self.classifier(out)

    def predict(self, prob):
        """Return the index of the largest entry along dimension 1 of ``prob``."""
        return prob.max(1)[1]

    def clone(self):
        """
        Return a new OmniglotModel carrying a copy of this model's state.

        The copy is placed on the same CUDA device as this model's
        parameters (not merely the current CUDA device), so that the two
        models can be combined parameter by parameter afterwards.
        """
        clone = OmniglotModel(self.num_classes)
        clone.load_state_dict(self.state_dict())
        if self.is_cuda():
            # A bare .cuda() would target the *current* device, which may
            # differ from the one this model actually lives on.
            clone.cuda(next(self.parameters()).get_device())
        return clone
if __name__ == '__main__':
    # Smoke test: push a zero batch of 5 flattened 28x28 inputs through a
    # 20-way model and report the input and output sizes.
    model = OmniglotModel(20)
    x = Variable(torch.zeros(5, 28*28))
    y = model(x)
    # A single pre-formatted argument keeps the printed text identical
    # under Python 2 and Python 3 (the print statement form is a
    # SyntaxError on Python 3 and would make the module unimportable).
    print('x {}'.format(x.size()))
    print('y {}'.format(y.size()))