-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNetworkLearner.py
125 lines (107 loc) · 5.64 KB
/
NetworkLearner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
from keras.callbacks import EarlyStopping, ModelCheckpoint
from DataParser import DataParser
class DataLearner:
    """Trains a neural network on the plot images produced by a DataParser.

    Parameters
    ----------
    neural_network : str
        Architecture name; "leonet" or "leonetv2" are supported.
    data_parser : DataParser
        Supplies the training (and optionally validation) generators.
    epochs : int
        Number of training epochs.
    early_stopping : bool
        Stop early on stagnating ``val_loss`` (only used with a validation set).
    save_best_checkpoint : bool
        Save the best model (by training loss) while training.
    use_validation_set : bool
        Whether to draw a validation generator from the parser.
    output : str or None
        Output model filename; derived from network + graph type when None.
    """

    def __init__(self, neural_network, data_parser, epochs=30, early_stopping=True, save_best_checkpoint=True,
                 use_validation_set=False, output=None):
        self.neural_network = neural_network
        self.epochs = epochs
        self.data_parser = data_parser
        self.early_stopping = early_stopping
        self.save_best_checkpoint = save_best_checkpoint
        self.use_validation_set = use_validation_set
        self.output = output

    def save(self, model):
        """Persist the trained model under the computed filename."""
        model.save(self.get_model_save_name())

    def get_model_save_name(self, checkpoint=False):
        """Return the ``.h5`` filename for the model.

        When ``checkpoint`` is True the name is prefixed with ``best_``
        (used by the ModelCheckpoint callback).
        """
        if self.output is None:
            ret = self.neural_network + "_" + self.data_parser.graph_type + ".h5"
        else:
            ret = self.output
            if not ret.endswith(".h5"):
                ret += ".h5"
        if checkpoint:
            return "best_" + ret
        return ret

    def train(self):
        """Train the selected network and return the fitted Keras model.

        Raises
        ------
        ValueError
            If ``self.neural_network`` is not a known architecture.
        """
        generator = self.data_parser.get_dataset_plot_generator()
        steps = len(self.data_parser.graph_files_name) // self.data_parser.batch_size
        if not self.use_validation_set:
            generator_val = None
            val_steps = None
        else:
            generator_val = self.data_parser.get_dataset_plot_val_generator()
            val_steps = len(self.data_parser.val_graph_files_name) // self.data_parser.batch_size
        cb = []
        # Early stopping monitors val_loss, so it only applies with a validation set.
        if self.early_stopping and self.use_validation_set:
            cb.append(EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5))
        if self.save_best_checkpoint:
            cb.append(ModelCheckpoint(self.get_model_save_name(checkpoint=True), monitor='loss', mode='min',
                                      save_best_only=True))
        if self.neural_network == "leonet":
            from models import LeoNet
            model = LeoNet.leonet_model((1, 224, 224, 3))
        elif self.neural_network == "leonetv2":
            from models import LeoNetV2
            model = LeoNetV2.LeoNetV2_model((1, 224, 224, 3))
        else:
            # Previously an unknown name fell through and returned None,
            # which crashed later in save(); fail fast instead.
            raise ValueError("Unknown neural network: %r" % self.neural_network)
        history = model.fit_generator(generator=generator, validation_data=generator_val,
                                      validation_steps=val_steps, epochs=self.epochs,
                                      steps_per_epoch=steps, callbacks=cb)
        self.print_history(history)
        return model

    def print_history(self, history):
        """Save accuracy and loss curves from ``history`` as PNG files."""
        import matplotlib.pyplot as plt
        name = self.get_model_save_name()
        # summarize history for accuracy
        plt.plot(history.history['acc'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        if self.use_validation_set:
            # Bug fix: the accuracy figure previously plotted 'val_loss'.
            plt.plot(history.history['val_acc'])
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
        plt.savefig(name + "_acc.png")
        plt.close()
        # summarize history for loss
        plt.plot(history.history['loss'])
        if self.use_validation_set:
            plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
        plt.savefig(name + "_loss.png")
        plt.close()
def main(neural_network, type_graph, folders, epochs, batch_size, val_percentage, output, early_stopping):
    """Build a DataParser for the given folders, train the requested model and save it."""
    parser = DataParser("training", folders, type_graph, batch_size=batch_size, val_percentage=val_percentage)
    # A non-zero validation percentage means the parser holds out a validation split.
    use_validation = val_percentage != 0.0
    learner = DataLearner(neural_network, parser, epochs=epochs, use_validation_set=use_validation,
                          output=output, early_stopping=early_stopping)
    trained_model = learner.train()
    learner.save(trained_model)
if __name__ == '__main__':
    def _str2bool(value):
        # Bug fix: argparse's type=bool is broken for strings — bool("False")
        # is True, so any non-empty argument parsed as True. Parse the common
        # textual booleans explicitly instead.
        return str(value).strip().lower() in ("true", "1", "yes", "y")

    parser = argparse.ArgumentParser(description="Customization options for the network learner")
    parser.add_argument("graph_type", nargs='?', default="mfcc", help='Set type of graph (i.e: mfcc/melspectrogram/melspectrogram-energy/spectrogram)')
    parser.add_argument("neural_network", nargs='?', default="leonetv2", help='Set the network you want to train (i.e: LeoNet/LeoNetV2)')
    parser.add_argument("folders", nargs='?', default="warblrb10k,ff1010bird",
                        help='Set of folders that will be used as the source of the graphs')
    parser.add_argument("batch_size", nargs='?', type=int, default=20,
                        help='Batch size of the files used to train the model')
    parser.add_argument("validation_percentage", nargs='?', type=float, default=0.0,
                        help='Fraction of the files held out for validation (0.0 disables validation)')
    parser.add_argument("epochs", nargs='?', type=int, default=30,
                        help='Number of epochs')
    parser.add_argument("output", nargs='?', default="leonetv2_mfcc_ff_ww",
                        help='Output filename')
    parser.add_argument("early_stopping", nargs='?', default="False", type=_str2bool,
                        help='Use early stopping when training the model')
    args = parser.parse_args()
    print(args)
    # Bug fix: early_stopping was parsed but then hard-coded to False in the
    # call below; forward the actual flag.
    main(str(args.neural_network).lower(), str(args.graph_type).lower(),
         [item.strip() for item in args.folders.strip().split(',')],
         args.epochs, args.batch_size, args.validation_percentage,
         args.output, args.early_stopping)