forked from dhlab-epfl/dhSegment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
135 lines (114 loc) · 6.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import tensorflow as tf
# Tensorflow logging level
from logging import WARNING # import DEBUG, INFO, ERROR for more/less verbosity
tf.logging.set_verbosity(WARNING)
from dh_segment import estimator_fn, utils
from dh_segment.io import input
import json
from glob import glob
import numpy as np
import mask_unused_gpus
# Optional dependency: prettier tracebacks while developing.
# A missing package is not fatal — we only warn and continue.
try:
    import better_exceptions
except ImportError:
    print('/!\ W -- Not able to import package better_exceptions')
    pass
from tqdm import trange
from sacred import Experiment
import pandas as pd
# Sacred experiment object; @ex.config and @ex.automain below register on it.
ex = Experiment('dhSegment_experiment')
@ex.config
def default_config():
    # Sacred config scope: every local variable defined below becomes a
    # config entry of the experiment, and the inline comments become the
    # documentation sacred shows for each entry. Do not rename locals —
    # that would change the public config keys.
    train_data = None # Directory with training data
    eval_data = None # Directory with validation data
    model_output_dir = None # Directory to output tf model
    restore_model = False # Set to true to continue training
    classes_file = None # txt file with classes values (unused for REGRESSION)
    gpu = '' # GPU to be used for training
    prediction_type = utils.PredictionType.CLASSIFICATION # One of CLASSIFICATION, REGRESSION or MULTILABEL
    pretrained_model_name = 'resnet50'
    model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict() # Model parameters
    training_params = utils.TrainingParams().to_dict() # Training parameters
    msi_params = utils.MSIParams().to_dict() # Training parameters
    use_ms = True
    # Derive n_classes from the prediction type; CLASSIFICATION and
    # MULTILABEL require a classes file, REGRESSION always predicts 1 value.
    # NOTE(review): these asserts are stripped under `python -O`; acceptable
    # here since sacred config scopes are normally run unoptimized.
    if prediction_type == utils.PredictionType.CLASSIFICATION:
        assert classes_file is not None
        model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
    elif prediction_type == utils.PredictionType.REGRESSION:
        model_params['n_classes'] = 1
    elif prediction_type == utils.PredictionType.MULTILABEL:
        assert classes_file is not None
        model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
@ex.automain
def run(train_data, eval_data, model_output_dir, gpu, training_params, _config, msi_params):
    """Train a dhSegment model, evaluating and exporting it periodically.

    :param train_data: directory containing 'images'/'labels' subdirectories,
        or a .csv file listing the training samples
    :param eval_data: same layout as ``train_data``, or None to skip evaluation
    :param model_output_dir: directory receiving checkpoints, config.json and exports
    :param gpu: legacy fixed-GPU id; unused now that mask_unused_gpus picks a free GPU
    :param training_params: dict form of utils.TrainingParams
    :param _config: full sacred config; saved to config.json and passed to the estimator
    :param msi_params: dict form of utils.MSIParams
    """
    # Create output directory; refuse to reuse an existing one unless the
    # user explicitly asked to continue training from it.
    if not os.path.isdir(model_output_dir):
        os.makedirs(model_output_dir)
    else:
        assert _config.get('restore_model'), \
            '{0} already exists, you cannot use it as output directory. ' \
            'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir)

    # Save the full experiment config next to the model for reproducibility
    with open(os.path.join(model_output_dir, 'config.json'), 'w') as f:
        json.dump(_config, f, indent=4, sort_keys=True)

    # Create export directory for saved models
    saved_model_dir = os.path.join(model_output_dir, 'export')
    if not os.path.isdir(saved_model_dir):
        os.makedirs(saved_model_dir)

    training_params = utils.TrainingParams.from_dict(training_params)
    msi_params = utils.MSIParams.from_dict(msi_params)

    session_config = tf.ConfigProto()
    # Old version to use a fixed gpu:
    # session_config.gpu_options.visible_device_list = str(gpu)
    # session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    # New version: automatically mask busy GPUs and use a free one
    mask_unused_gpus.mask_unused_gpus(2)

    estimator_config = tf.estimator.RunConfig().replace(session_config=session_config,
                                                        save_summary_steps=10,
                                                        keep_checkpoint_max=1)
    estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir,
                                       params=_config, config=estimator_config)

    def get_dirs_or_files(input_data):
        """Resolve input_data to (images_source, labels_dir_or_None).

        Accepts either a directory containing 'images' and 'labels'
        subdirectories, or a .csv file listing image/label pairs (in which
        case the labels directory is None).
        """
        if os.path.isdir(input_data):
            images_input, labels_input = os.path.join(input_data, 'images'), os.path.join(input_data, 'labels')
            # Check that the expected subdirectories exist
            if not os.path.isdir(images_input):
                raise FileNotFoundError(images_input)
            if not os.path.isdir(labels_input):
                raise FileNotFoundError(labels_input)
        # BUG FIX: the original branch tested and returned the closed-over
        # 'train_data' instead of the 'input_data' argument, so passing a
        # .csv eval_data silently evaluated on the training csv.
        elif os.path.isfile(input_data) and input_data.endswith('.csv'):
            images_input = input_data
            labels_input = None
        else:
            raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data))
        return images_input, labels_input

    train_input, train_labels_input = get_dirs_or_files(train_data)
    if eval_data is not None:
        eval_input, eval_labels_input = get_dirs_or_files(eval_data)

    # Configure exporter: keeps the 2 best saved models according to eval results
    serving_input_fn = input.serving_input_filename(training_params.input_resized_size, msi_params.channel_ids, msi_params.separator)
    exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2)

    # Train for n_epochs, pausing every evaluate_every_epoch epochs to
    # evaluate (if eval data was given) and export the current best model.
    for i in trange(0, training_params.n_epochs, training_params.evaluate_every_epoch, desc='Evaluated epochs'):
        estimator.train(input.input_fn(train_input,
                                       input_label_dir=train_labels_input,
                                       num_epochs=training_params.evaluate_every_epoch,
                                       batch_size=training_params.batch_size,
                                       data_augmentation=training_params.data_augmentation,
                                       make_patches=training_params.make_patches,
                                       image_summaries=True,
                                       params=_config,
                                       num_threads=32))
        if eval_data is not None:
            eval_result = estimator.evaluate(input.input_fn(eval_input,
                                                            input_label_dir=eval_labels_input,
                                                            batch_size=1,
                                                            data_augmentation=False,
                                                            make_patches=False,
                                                            image_summaries=False,
                                                            params=_config,
                                                            num_threads=32))
        else:
            eval_result = None
        exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
                        is_the_final_export=False)