-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
28 lines (20 loc) · 955 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from keras import Input, Model
from keras import backend as K
from keras.layers import Conv2D, Dropout, Flatten, Dense
class ClassificationNet:
    """Builder for a small convolutional classification network (Keras functional API)."""

    @staticmethod
    def build(width_height_channel, num_classes):
        """Construct the ClassificationNet model.

        Args:
            width_height_channel: 3-element sequence unpacked in the order
                (width, height, channel).
            num_classes: Number of units in the final softmax layer.

        Returns:
            A Keras ``Model`` named ``'ClassificationNet'``. ``compile`` is
            not called here; the caller is responsible for compiling it.
        """
        width, height, channel = width_height_channel

        # Arrange the input axes to match the backend's configured data format.
        channels_first = K.image_data_format() == 'channels_first'
        input_shape = (
            (channel, height, width) if channels_first
            else (height, width, channel)
        )

        inputs = Input(shape=input_shape)

        # Convolutional feature extractor, applied in this order:
        # (filters, kernel_size, extra Conv2D keyword arguments).
        conv_stack = (
            (32, (5, 5), {'strides': 2}),
            (64, (3, 3), {}),
            (128, (3, 3), {}),
        )
        features = inputs
        for filters, kernel, extra in conv_stack:
            features = Conv2D(
                filters, kernel_size=kernel, activation='relu', **extra
            )(features)

        # Classifier head: flatten, one hidden ReLU layer, softmax over classes.
        # Flatten is created before the hidden Dense to keep the original
        # layer-creation order.
        flat = Flatten()(features)
        hidden = Dense(128, activation='relu')(flat)
        probabilities = Dense(
            num_classes, activation='softmax', name='classification_net_output'
        )(hidden)

        return Model(inputs=inputs, outputs=probabilities, name='ClassificationNet')