-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
49 lines (35 loc) · 1.34 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
import torchvision.models as models
from torchvision.datasets.utils import download_url
class Vgg16:
    """Thin wrapper around torchvision's ImageNet-pretrained VGG-16.

    Constructing an instance downloads (if not already present) the ImageNet
    class-index JSON into ``data/`` and loads it into ``self.labels``.

    Attributes:
        model: the wrapped ``torchvision.models.vgg16`` network.
        name: short identifier for this wrapper (``"vgg"``).
        ce_layer_name: layer identifier consumed elsewhere; presumably refers
            to ``model.features[29]`` -- TODO confirm against its consumer.
        labels: parsed contents of ``data/imagenet_class_index.json``.
    """

    def __init__(self):
        self.model = self._build_model()
        self.name = "vgg"
        self.ce_layer_name = 'features_29'
        self.labels = self._load_labels()

    @staticmethod
    def _build_model():
        """Return VGG-16 with ImageNet weights, across torchvision versions."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return models.vgg16(pretrained=True)
        # torchvision >= 0.13 deprecates ``pretrained``; IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` used to select.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1)

    @staticmethod
    def _load_labels():
        """Download (if missing) and parse the ImageNet class-index JSON."""
        # Use root="data" so download_url creates the directory itself.
        # root="." with filename "data/..." fails when ./data is missing.
        download_url(
            "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
            "data",
            "imagenet_class_index.json",
        )
        with open("data/imagenet_class_index.json", "r", encoding="utf-8") as h:
            return json.load(h)

    def eval(self):
        """Put the wrapped network into evaluation mode."""
        self.model.eval()

    def predict(self, img):
        """Run the wrapped network on ``img`` and return its raw output.

        No post-processing (e.g. softmax) is applied here.
        # NOTE(review): presumably ``img`` is a preprocessed image batch
        # tensor -- confirm against callers.
        """
        predictions = self.model(img)
        return predictions

    def __call__(self, x):
        """Alias for :meth:`predict`."""
        return self.predict(x)
class AlexNet:
    """Thin wrapper around torchvision's ImageNet-pretrained AlexNet.

    Constructing an instance downloads (if not already present) the ImageNet
    class-index JSON into ``data/`` and loads it into ``self.labels``.

    Attributes:
        model: the wrapped ``torchvision.models.alexnet`` network.
        name: short identifier for this wrapper (``"alexnet"``).
        ce_layer_name: layer identifier consumed elsewhere; presumably refers
            to ``model.features[11]`` -- TODO confirm against its consumer.
        labels: parsed contents of ``data/imagenet_class_index.json``.
    """

    def __init__(self):
        self.model = self._build_model()
        self.name = "alexnet"
        self.ce_layer_name = "features_11"
        self.labels = self._load_labels()

    @staticmethod
    def _build_model():
        """Return AlexNet with ImageNet weights, across torchvision versions."""
        weights_enum = getattr(models, "AlexNet_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return models.alexnet(pretrained=True)
        # torchvision >= 0.13 deprecates ``pretrained``; IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` used to select.
        return models.alexnet(weights=weights_enum.IMAGENET1K_V1)

    @staticmethod
    def _load_labels():
        """Download (if missing) and parse the ImageNet class-index JSON."""
        # Use root="data" so download_url creates the directory itself.
        # root="." with filename "data/..." fails when ./data is missing.
        download_url(
            "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
            "data",
            "imagenet_class_index.json",
        )
        with open("data/imagenet_class_index.json", "r", encoding="utf-8") as h:
            return json.load(h)

    def eval(self):
        """Put the wrapped network into evaluation mode."""
        self.model.eval()

    def predict(self, img):
        """Run the wrapped network on ``img`` and return its raw output.

        No post-processing (e.g. softmax) is applied here.
        # NOTE(review): presumably ``img`` is a preprocessed image batch
        # tensor -- confirm against callers.
        """
        predictions = self.model(img)
        return predictions

    def __call__(self, x):
        """Alias for :meth:`predict`."""
        return self.predict(x)