-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
81 lines (70 loc) · 2.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

from tensorflow.keras import Model
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Training hyper-parameters.
BATCH_SIZE = 32   # images per batch for both the training and validation iterators
MAX_EPOCHS = 100  # upper bound on epochs; EarlyStopping may end training sooner

# Image generator shared by the training and validation iterators. It only
# multiplies pixel values by 1/255; no augmentation options are set.
# NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled to [-1, 1]
# (mobilenet_v2.preprocess_input), not [0, 1]. Fine-tuning can absorb this,
# but confirm it is intended before changing it, since any inference code
# must apply exactly the same scaling as training.
generator = ImageDataGenerator(
    rescale=1.0 / 255,
    data_format='channels_last',
)
def _directory_batches(directory):
    """Return a batch iterator over the class sub-folders under *directory*.

    Images are resized to 96x96 and labelled with one-hot ('categorical')
    targets, using the module-level ``generator`` and ``BATCH_SIZE``.
    """
    return generator.flow_from_directory(
        directory=directory,
        target_size=[96, 96],
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )


# Training and validation data come from separate directory trees.
train_batches = _directory_batches('dataset/cell_images_train')
val_batches = _directory_batches('dataset/cell_images_validation')
# Model: ImageNet-pretrained MobileNetV2 feature extractor (without its own
# classifier head) followed by a flattened 2-way softmax classifier.
# `classes` is not passed to MobileNetV2: Keras only uses it when
# include_top=True, so it had no effect here.
base_model = MobileNetV2(
    input_shape=(96, 96, 3),
    weights='imagenet',
    include_top=False,
)
flat = Flatten()(base_model.output)
output = Dense(2, activation='softmax')(flat)
model = Model(inputs=base_model.input, outputs=output)

# Prepare model to run.
# The metric is registered as 'acc' (rather than 'accuracy') so the logged
# keys are 'acc' / 'val_acc' under both TF 1.x and TF 2.x. The callbacks
# monitor 'val_acc' and the checkpoint filename formats '{val_acc}'. Under
# TF 2.x the name 'accuracy' would log 'val_accuracy' instead, so those
# callbacks would never trigger and the filename could not be formatted.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['acc'],
)
# Checkpoints are written under ./checkpoints. Create the directory up front,
# because some Keras versions do not create missing parent directories when
# saving HDF5 files, which would make the first save fail.
os.makedirs('./checkpoints', exist_ok=True)

# Callback to save the full model (save_weights_only=False) each time
# val_acc reaches a new maximum. The filename records the epoch and score.
model_checkpoint_callback = ModelCheckpoint(
    './checkpoints/{epoch:02d}_{val_acc:.4f}.h5',
    save_weights_only=False,
    verbose=1,
    monitor='val_acc',
    save_best_only=True,
    mode='max',
)

# Callback to plot data on TensorBoard. `batch_size` is no longer passed:
# TF 1.x only used it to size batches for histogram computation, which
# histogram_freq=0 disables, and TF 2.x ignores it with a warning.
tensorboard_callback = TensorBoard(
    log_dir='./logs/malaria',
    histogram_freq=0,
)

# Callback to halve the learning rate when val_acc stops improving for 6
# epochs, never going below 1e-6. mode='max' is stated explicitly to match
# the other callbacks. mode='auto' already inferred it from the monitor name.
reduce_lr_callback = ReduceLROnPlateau(
    monitor='val_acc',
    factor=0.5,
    patience=6,
    mode='max',
    min_lr=1e-6,
)

# Callback to stop training once val_acc has not improved for 30 epochs.
early_stopping_callback = EarlyStopping(
    monitor='val_acc',
    patience=30,
    mode='max',
)
# Starts training the model. `fit_generator` is deprecated in TF 2.x and
# removed in Keras 3. `fit` accepts the DirectoryIterator batches (Keras
# Sequences) directly and derives steps per epoch from their length.
model.fit(
    train_batches,
    epochs=MAX_EPOCHS,
    verbose=1,
    validation_data=val_batches,
    callbacks=[
        model_checkpoint_callback,
        tensorboard_callback,
        reduce_lr_callback,
        early_stopping_callback,
    ],
)