-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
39 lines (33 loc) · 1.22 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
import torchvision
class ConvNet(nn.Module):
    """Fully connected classification head stacked on a backbone's feature extractor.

    The backbone's last child module (its original classifier) is dropped; the
    remaining children run in order, their output is flattened, and the result
    goes through three linear layers (in_features -> 4096 -> 256 -> num_classes)
    with LeakyReLU activations in between.

    Args:
        model: backbone network; every child from ``model.children()`` except
            the last one becomes the feature extractor.
        num_classes: number of output logits.
        in_features: size of the flattened feature-extractor output. The
            default, 62720 (= 1280 * 7 * 7), matches torchvision's
            mobilenet_v2 on 224x224 inputs; pass a different value for other
            backbones or input resolutions.
    """

    def __init__(self, model, num_classes, in_features=62720):
        super().__init__()
        # Keep every child except the last one (the backbone's own classifier).
        self.base_model = nn.Sequential(*list(model.children())[:-1])
        self.linear1 = nn.Linear(in_features=in_features, out_features=4096)
        self.linear2 = nn.Linear(in_features=4096, out_features=256)
        self.linear3 = nn.Linear(in_features=256, out_features=num_classes)
        # Despite the attribute name this is a LeakyReLU (default slope); the
        # name is kept so existing references to ``.relu`` keep working.
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """Run the backbone and head.

        Args:
            x: input batch accepted by the backbone's feature extractor.

        Returns:
            ``(lin, logits)`` where ``lin`` is the 256-wide output of
            ``linear2`` taken *before* the activation (usable as an embedding)
            and ``logits`` has ``num_classes`` columns.
        """
        x = self.base_model(x)
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, 1)
        x = self.relu(self.linear1(x))
        lin = self.linear2(x)
        logits = self.linear3(self.relu(lin))
        return lin, logits
def get_model(device, num_classes):
    """Build a ConvNet head on top of an ImageNet-pretrained torchvision mobilenet_v2.

    Args:
        device: target device, anything accepted by ``nn.Module.to``.
        num_classes: number of output classes for the new head.

    Returns:
        The ConvNet instance, moved to ``device``.
    """
    # torchvision >= 0.13 exposes weight enums and deprecates ``pretrained=``;
    # fall back to the legacy flag on older releases. IMAGENET1K_V1 is the
    # checkpoint the legacy ``pretrained=True`` flag selects.
    weights_enum = getattr(torchvision.models, "MobileNet_V2_Weights", None)
    if weights_enum is not None:
        backbone = torchvision.models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
    else:
        backbone = torchvision.models.mobilenet_v2(pretrained=True)
    # One move after wrapping is enough: ConvNet registers the backbone's
    # children as submodules, so .to() reaches them as well.
    return ConvNet(backbone, num_classes).to(device)
def test_model(model, device, input_shape=(4, 3, 224, 224)):
    """Smoke-test ``model`` with a constant batch and print the output shapes.

    The model is temporarily switched to eval mode and run under
    ``torch.no_grad()`` so the check neither updates BatchNorm running
    statistics nor builds an autograd graph; its previous train/eval mode is
    restored afterwards, even if the forward pass raises.

    Args:
        model: module whose forward returns a ``(feature, out)`` pair.
        device: device to place the dummy input on.
        input_shape: shape of the all-ones float32 dummy batch. Defaults to
            four 3x224x224 images.

    Returns:
        ``(feature.shape, out.shape)``.
    """
    vec = torch.ones(input_shape, dtype=torch.float32, device=device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            feature, out = model(vec)
    finally:
        model.train(was_training)
    print(feature.shape, out.shape)
    return feature.shape, out.shape


# The name starts with ``test_`` but this is a helper that needs arguments;
# stop pytest from collecting it when it is imported into a test module.
test_model.__test__ = False