forked from tslgithub/image_class
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Build_model.py
133 lines (115 loc) · 6.06 KB
/
Build_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python
#-*- coding:utf-8 -*-
# author:"tsl"
# email:"[email protected]"
# datetime:19-1-17 3:07 PM
# software: PyCharm
# Model factory: builds a Keras classification model selected by name
# (config.model_name) and compiles it for training.
from __future__ import print_function
import keras
from MODEL import MODEL,ResnetBuilder
import sys
# NOTE(review): presumably raised so that building or saving very deep models
# does not hit Python's default recursion limit -- confirm.
sys.setrecursionlimit(10000)
from keras import backend as K
# import densenet  # module-level DenseNet import disabled (original note: "densenet model cancelled")
class Build_model(object):
    """Build and compile a Keras classification model selected by name.

    Settings are copied from a ``config`` object at construction time;
    :meth:`build_model` instantiates ``config.model_name`` via
    :meth:`model_confirm` and compiles it via :meth:`model_compile`.
    """

    # Names that are constructors on ``keras.applications``; each is built from
    # scratch (``weights=None``) with its own classification top.
    _KERAS_APPLICATIONS = ('VGG16', 'VGG19', 'ResNet50', 'InceptionV3',
                           'Xception', 'MobileNet')
    # Model name -> zero-argument method of ``MODEL(config)`` that builds it.
    _MODEL_METHODS = {
        'AlexNet': 'AlexNet',
        'LeNet': 'LeNet',
        'ZF_Net': 'ZF_Net',
        'mnist_net': 'mnist_net',
        'VGG16_TSL': 'VGG16_TSL',
    }
    # Model name -> method of ``ResnetBuilder()`` that builds it from ``config``.
    _RESNET_BUILDERS = {
        'ResNet18': 'build_resnet18',
        'ResNet34': 'build_resnet34',
        'ResNet101': 'build_resnet101',
        'ResNet152': 'build_resnet152',
    }

    def __init__(self, config):
        """Copy the settings this builder needs from ``config``.

        :param config: object exposing every attribute read below. Note that
            the channel count is read from the (misspelled) ``channles``
            attribute; that spelling is part of the config interface.
        """
        self.train_data_path = config.train_data_path
        self.checkpoints = config.checkpoints
        self.normal_size = config.normal_size
        self.channles = config.channles
        self.epochs = config.epochs
        self.batch_size = config.batch_size
        self.classes = config.classes
        self.model_name = config.model_name
        self.lr = config.lr
        self.config = config
        self.default_optimizers = config.default_optimizers
        self.data_augmentation = config.data_augmentation
        self.rat = config.rat
        self.cut = config.cut

    def _input_shape(self, channels_first=False):
        """Return the square image input shape; channels-last unless asked otherwise."""
        if channels_first:
            return (self.channles, self.normal_size, self.normal_size)
        return (self.normal_size, self.normal_size, self.channles)

    def _build_densenet(self):
        """Build the project-local DenseNet (depth 40, 3 dense blocks, growth rate 12).

        ``densenet`` is imported here, only when this model is requested, so the
        rest of the class works even when that module is not present.

        :raises ImportError: if the ``densenet`` module cannot be imported.
        """
        try:
            import densenet
        except ImportError:
            raise ImportError("model 'DenseNet' requires the project-local "
                              "'densenet' module, which could not be imported")
        # 'th' (Theano) dim ordering is channels-first; otherwise channels-last.
        img_dim = self._input_shape(channels_first=K.image_dim_ordering() == "th")
        return densenet.DenseNet(img_dim, classes=self.classes, depth=40,
                                 nb_dense_block=3, growth_rate=12, nb_filter=12,
                                 dropout_rate=0.0, bottleneck=False,
                                 reduction=0.0, weights=None)

    def model_confirm(self, choosed_model):
        """Instantiate (without compiling) the model named ``choosed_model``.

        :param choosed_model: a name from ``_KERAS_APPLICATIONS``,
            ``_MODEL_METHODS`` or ``_RESNET_BUILDERS``, or ``'DenseNet'``.
        :returns: the uncompiled Keras model.
        :raises ValueError: if ``choosed_model`` is not a supported name.
        """
        if choosed_model in self._KERAS_APPLICATIONS:
            builder = getattr(keras.applications, choosed_model)
            # ``pooling`` only applies when ``include_top`` is False, so Keras
            # ignores it here; it is kept to match the historical call.
            return builder(include_top=True,
                           weights=None,
                           input_tensor=None,
                           input_shape=self._input_shape(),
                           pooling='max',
                           classes=self.classes)
        if choosed_model in self._MODEL_METHODS:
            method_name = self._MODEL_METHODS[choosed_model]
            return getattr(MODEL(self.config), method_name)()
        if choosed_model in self._RESNET_BUILDERS:
            method_name = self._RESNET_BUILDERS[choosed_model]
            return getattr(ResnetBuilder(), method_name)(self.config)
        if choosed_model == 'DenseNet':
            return self._build_densenet()
        supported = (list(self._KERAS_APPLICATIONS) + sorted(self._MODEL_METHODS)
                     + sorted(self._RESNET_BUILDERS) + ['DenseNet'])
        raise ValueError('Unknown model name %r; expected one of: %s'
                         % (choosed_model, ', '.join(supported)))

    def model_compile(self, model):
        """Compile ``model`` with categorical cross-entropy and accuracy; return it.

        When ``default_optimizers`` is truthy, an explicit Adam optimizer using the
        configured learning rate ``self.lr`` is built; otherwise the string
        ``"adam"`` is passed so Keras uses its default Adam settings.
        """
        if self.default_optimizers:
            optimizer = keras.optimizers.Adam(lr=self.lr, beta_1=0.9, beta_2=0.999,
                                              epsilon=None, decay=0.0)
        else:
            optimizer = "adam"
        # A Keras model must be compiled (loss/optimizer/metrics set) before training.
        model.compile(loss="categorical_crossentropy", optimizer=optimizer,
                      metrics=["accuracy"])
        return model

    def build_model(self):
        """Build the model named by ``config.model_name`` and return it compiled."""
        model = self.model_confirm(self.model_name)
        model = self.model_compile(model)
        return model