-
Notifications
You must be signed in to change notification settings - Fork 15
/
models.py
62 lines (52 loc) · 1.94 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
from torchvision.models import resnet34
class FaceNetModel(nn.Module):
    """ResNet-34 backbone that emits scaled, L2-normalised embeddings.

    The torchvision ``resnet34`` is evaluated up to and including ``layer4``.
    That output is flattened per sample (no pooling is applied in between)
    and projected to ``embedding_size`` by a replacement ``fc`` layer.  The
    input width of ``fc`` is measured in ``__init__`` with a 224x224 RGB
    probe, so ``forward`` expects inputs whose ``layer4`` output flattens to
    that same size.

    An extra linear layer, ``model.classifier``, maps embeddings to
    ``num_classes`` logits and is used by ``forward_classifier``.

    Args:
        embedding_size: Length of the embedding vector produced by ``forward``.
        num_classes: Number of outputs of the classification layer.
        pretrained: Forwarded to ``resnet34`` as its ``pretrained`` argument.
        alpha: Factor applied to the unit-norm embedding.  The default of 10
            follows https://arxiv.org/pdf/1703.09507.pdf.
    """

    # Class-level default so that instances created without running
    # ``__init__`` still scale their embeddings by 10.
    alpha = 10

    def __init__(self, embedding_size, num_classes, pretrained=False, alpha=10):
        super(FaceNetModel, self).__init__()
        self.model = resnet34(pretrained=pretrained)
        self.embedding_size = embedding_size
        self.alpha = alpha
        self.output_conv = self._get_output_conv((1, 3, 224, 224))
        self.model.fc = nn.Linear(self.output_conv, self.embedding_size)
        self.model.classifier = nn.Linear(self.embedding_size, num_classes)

    def l2_norm(self, input):
        """Scale ``input`` to unit L2 norm along dimension 1.

        Args:
            input: Tensor whose dimension 1 holds the vectors to normalise.

        Returns:
            A tensor of the same shape as ``input``.
        """
        # 1e-10 is added to the squared norm so that an all-zero vector is
        # mapped to zeros instead of producing a division by zero.
        squared = torch.sum(torch.pow(input, 2), dim=1, keepdim=True)
        return input / torch.sqrt(squared + 1e-10)

    def forward(self, x):
        """Return the embedding of ``x``: unit L2 norm multiplied by ``alpha``.

        The result is also stored on the module as ``self.features``.
        """
        embedding = self.model.fc(self._backbone(x))
        self.features = self.l2_norm(embedding) * self.alpha
        return self.features

    def forward_classifier(self, x):
        """Return class logits computed from the embedding of ``x``."""
        return self.model.classifier(self.forward(x))

    def _backbone(self, x):
        """Apply the stem and the four residual stages, then flatten per sample."""
        net = self.model
        stages = (
            net.conv1,
            net.bn1,
            net.relu,
            net.maxpool,
            net.layer1,
            net.layer2,
            net.layer3,
            net.layer4,
        )
        for stage in stages:
            x = stage(x)
        return x.view(x.size(0), -1)

    def _get_output_conv(self, shape):
        """Return the flattened ``layer4`` feature size for an input of ``shape``.

        Only the size of the result is used, so the probe must leave the
        backbone untouched.  It therefore runs in eval mode (in training mode
        the batch-norm layers would update their running statistics from the
        random probe) and without building an autograd graph.  The previous
        train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                flat = self._backbone(torch.rand(shape))
        finally:
            self.model.train(was_training)
        return flat.size(1)