-
Notifications
You must be signed in to change notification settings - Fork 2
/
epi_irv2_augment.py
64 lines (56 loc) · 2.25 KB
/
epi_irv2_augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

from tensorflow.keras import preprocessing as preproc
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.models import Sequential as seq, load_model as load
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
from tensorflow.keras.optimizers import Adadelta
from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN
from tensorflow.keras.metrics import Accuracy, AUC, Precision, Recall, SpecificityAtSensitivity
from pandas import DataFrame as df
# Training-time augmentation.
# ``brightness_range`` is a pair of *multiplicative* brightness factors
# (Keras passes them to PIL's ImageEnhance.Brightness; 1.0 = unchanged).
# The previous [25.5, 65.5] brightened every image 25-65x, saturating almost
# all pixels to white. Those values look like pixel offsets (10% / ~25% of
# 255) rather than factors, so a mild range around 1.0 is used instead.
# (shear_range=0.3 and horizontal_flip=True were commented out in the
# original configuration and remain disabled.)
datagen = preproc.image.ImageDataGenerator(
    validation_split=.18,
    rescale=1./255,
    brightness_range=[0.8, 1.2],
    zoom_range=0.2,
)
# Validation images must NOT be randomly augmented. Otherwise val metrics
# (and the val_acc-based checkpoint selection) fluctuate because of the
# random transforms rather than the model. Keras splits a directory by
# sorted filename using only ``validation_split``, so this generator
# selects exactly the held-out files that ``datagen`` excludes from its
# "training" subset.
val_datagen = preproc.image.ImageDataGenerator(
    validation_split=.18,
    rescale=1./255,
)
# Training subset: augmented, shuffled each epoch.
train = datagen.flow_from_directory(
    directory="data/plot_epi/train/",
    class_mode="categorical",
    color_mode="rgb",
    target_size=(299, 299),
    shuffle=True,
    interpolation="bilinear",
    seed=42,
    subset="training",
)
# Validation subset: rescale only, fixed order.
val = val_datagen.flow_from_directory(
    directory="data/plot_epi/train/",
    class_mode="categorical",
    color_mode="rgb",
    target_size=(299, 299),
    shuffle=False,
    interpolation="bilinear",
    seed=42,
    subset="validation",
)
# Sequential classifier for 299x299 RGB inputs:
# InceptionResNetV2 backbone (ImageNet weights, no top classifier)
# -> global average pooling -> 2-unit softmax head.
epi_InceptionResNetV2_model = seq(
    layers=[
        Input(shape=(299, 299, 3)),
        InceptionResNetV2(
            weights="imagenet",
            include_top=False,
            input_shape=(299, 299, 3),
        ),
        GlobalAveragePooling2D(),
        Dense(units=2, activation="softmax"),
    ],
    name="EPI_InceptionResNetV2",
)
# Categorical cross-entropy pairs with the one-hot labels produced by
# class_mode="categorical" and the 2-unit softmax output.
# NOTE(review): Precision/Recall/SpecificityAtSensitivity/AUC are built
# without ``class_id``. Keras then evaluates them over both softmax columns
# together (micro-averaged), so precision and recall track accuracy closely.
# Confirm whether per-class metrics were intended.
epi_InceptionResNetV2_model.compile(
    optimizer=Adadelta(learning_rate=1e-2),
    loss="categorical_crossentropy",
    metrics=[
        "acc",
        Precision(thresholds=.51),
        Recall(thresholds=.51),
        SpecificityAtSensitivity(sensitivity=.5),
        AUC(),
    ],
)
# Training callbacks:
#  - write an .h5 checkpoint only when val_acc improves on the best so far;
#  - stop training as soon as the loss becomes NaN.
callbacks = [
    ModelCheckpoint(
        filepath="ckpt/checkpoint-augment-inceptionresnetv2-epi-{epoch:02d}-{val_acc:.3f}.h5",
        monitor="val_acc",
        mode="max",
        save_best_only=True,
    ),
    TerminateOnNaN(),
]
# Create every output directory before training starts. pandas' to_csv does
# not create missing parent directories, so a missing result/ would
# otherwise only fail after all epochs had already run, losing the history.
for _out_dir in ("ckpt", "model", "result"):
    os.makedirs(_out_dir, exist_ok=True)

# Train for up to 30 epochs; the callbacks may checkpoint or stop early.
epi_InceptionResNetV2_model_result = epi_InceptionResNetV2_model.fit(
    x=train, validation_data=val, epochs=30, callbacks=callbacks)

# These save the model as it is when fit() returns (the last epoch run),
# not necessarily the best one. The best-val_acc model is the checkpoint
# written by ModelCheckpoint under ckpt/.
epi_InceptionResNetV2_model.save("model/augment_epi_InceptionResNetV2_model.h5")
epi_InceptionResNetV2_model.save_weights("model/augment_epi_InceptionResNetV2_weights.h5")

# Persist per-epoch loss/metric history (one column per history key).
df.from_dict(epi_InceptionResNetV2_model_result.history).to_csv('result/augment_epi_InceptionResNetV2_model_result.csv', index=False)