-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy pathmodel.py
120 lines (108 loc) · 7.12 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch import nn
from models import resnet, pre_act_resnet, wide_resnet, resnext, densenet
class _ModelOptionError(ValueError, AssertionError):
    """Raised by ``generate_model`` when an option value is not supported.

    Derives from ``ValueError`` (the idiomatic type for a bad argument value)
    and from ``AssertionError`` so that callers written against the earlier
    ``assert``-based validation keep catching it. Unlike an ``assert``
    statement, an explicit ``raise`` is not stripped under ``python -O``.
    """


# Accepted values of ``opt.model_name``.
_MODEL_NAMES = ('resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet')


def _factories_for(model_name):
    """Return a ``{depth: constructor}`` dict for *model_name*.

    *model_name* must already be one of ``_MODEL_NAMES``; the keys of the
    returned dict are the depths supported for that architecture.
    """
    if model_name == 'resnet':
        return {
            10: resnet.resnet10,
            18: resnet.resnet18,
            34: resnet.resnet34,
            50: resnet.resnet50,
            101: resnet.resnet101,
            152: resnet.resnet152,
            200: resnet.resnet200,
        }
    if model_name == 'preresnet':
        return {
            18: pre_act_resnet.resnet18,
            34: pre_act_resnet.resnet34,
            50: pre_act_resnet.resnet50,
            101: pre_act_resnet.resnet101,
            152: pre_act_resnet.resnet152,
            200: pre_act_resnet.resnet200,
        }
    if model_name == 'wideresnet':
        return {50: wide_resnet.resnet50}
    if model_name == 'resnext':
        return {
            50: resnext.resnet50,
            101: resnext.resnet101,
            152: resnext.resnet152,
        }
    # 'densenet'
    return {
        121: densenet.densenet121,
        169: densenet.densenet169,
        201: densenet.densenet201,
        264: densenet.densenet264,
    }


def _extra_kwargs(opt):
    """Return the constructor keyword arguments specific to ``opt.model_name``.

    DenseNet takes none; every ResNet variant takes ``shortcut_type``, plus
    ``k`` for wide ResNet and ``cardinality`` for ResNeXt. Each option is only
    read for the architecture that uses it.
    """
    if opt.model_name == 'densenet':
        return {}
    extra = {'shortcut_type': opt.resnet_shortcut}
    if opt.model_name == 'wideresnet':
        extra['k'] = opt.wide_resnet_k
    elif opt.model_name == 'resnext':
        extra['cardinality'] = opt.resnext_cardinality
    return extra


def generate_model(opt):
    """Build the network selected by *opt*.

    Args:
        opt: options namespace. Always reads ``mode``, ``model_name``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration`` and ``no_cuda``; depending on the
            architecture also ``resnet_shortcut``, ``wide_resnet_k`` and
            ``resnext_cardinality``.

    Returns:
        The constructed model. When ``opt.no_cuda`` is false it is moved to
        the GPU with ``.cuda()`` and wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``mode``, ``model_name`` or ``model_depth`` is not
            supported. The exception is also an ``AssertionError`` for
            backward compatibility.
    """
    if opt.mode not in ('score', 'feature'):
        raise _ModelOptionError(
            "unsupported mode {!r}; expected 'score' or 'feature'".format(
                opt.mode))
    if opt.model_name not in _MODEL_NAMES:
        raise _ModelOptionError(
            'unsupported model_name {!r}; expected one of: {}'.format(
                opt.model_name, ', '.join(_MODEL_NAMES)))

    factories = _factories_for(opt.model_name)
    if opt.model_depth not in factories:
        raise _ModelOptionError(
            'unsupported model_depth {!r} for {!r}; expected one of: {}'.format(
                opt.model_depth, opt.model_name, sorted(factories)))

    model = factories[opt.model_depth](
        num_classes=opt.n_classes,
        sample_size=opt.sample_size,
        sample_duration=opt.sample_duration,
        # 'score' mode passes last_fc=True, 'feature' mode passes False.
        last_fc=opt.mode == 'score',
        **_extra_kwargs(opt))

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
    return model