-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_model.py
161 lines (127 loc) · 7.53 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import importlib
import os
import pickle
from glob import glob

import h5py
import numpy as np
import torch

import iotools
def train_model(args):
    """Train a Water Cherenkov generative model.

    Loads the model class ``models.<args.model>``, optionally restores network
    and optimizer state, streams training/validation batches from HDF5 files
    via ``iotools``, and periodically saves checkpoints and loss records to
    ``args.output_dir``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (see the ``__main__`` block for the
        full list: epochs, batch_size, data_dirs, model, ...).
    """
    # Set random seeds for reproducibility (CPU, all GPUs, cuDNN, numpy).
    if args.random_seed is not None:
        # BUGFIX: original used `"...{0}".args.random_seed`, which is an
        # attribute access on a str literal and raises AttributeError.
        print("Setting random seed to {0}".format(args.random_seed))
        torch.manual_seed(args.random_seed)
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)  # if you are using multi-GPU.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        np.random.seed(args.random_seed)

    # Get and initialize model
    print("Loading model: "+args.model)
    # Collect model options in a dictionary; each option arrives as
    # "name:value" and numeric-looking values are coerced to int/float.
    model_args_dict = {}
    if len(args.model_arguments):
        for model_arg in args.model_arguments:
            arg_name, _, raw_value = model_arg.partition(":")
            try:
                arg_value = float(raw_value)
                if arg_value.is_integer():
                    arg_value = int(arg_value)
            except ValueError:
                # Not numeric: keep the raw string.
                arg_value = raw_value
            model_args_dict[arg_name] = arg_value
        print("With options: ", model_args_dict)

    # Import and instantiate the requested model class.
    model_module = importlib.import_module("models."+args.model)
    network = model_module.model(**model_args_dict)

    # Optionally resume from saved network / optimizer states.
    if args.network_state is not None:
        print("Loading net")
        network.load_state_dict(torch.load(args.network_state, map_location=network.device))
    if args.optimizer_state is not None:
        print("Loading opt")
        network.optimizer.load_state_dict(torch.load(args.optimizer_state, map_location=network.device))

    # Initialize data loaders: first `train_fraction` of the data trains,
    # the remainder validates.
    print("Data directory: "+args.data_dirs)
    print("Data flavour: "+args.data_flavour)
    read_keys = ["positions", "directions", "energies", "event_data_top", "event_data_bottom"]
    train_loader = iotools.loader_factory('H5Dataset', batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, data_dirs=args.data_dirs.split(","), flavour=args.data_flavour, start_fraction=0.0, use_fraction=args.train_fraction, read_keys=read_keys)
    test_loader = iotools.loader_factory('H5Dataset', batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, data_dirs=args.data_dirs.split(","), flavour=args.data_flavour, start_fraction=args.train_fraction, use_fraction=1.-args.train_fraction, read_keys=read_keys)

    # Grab end-cap masks from one of the input files and flatten each to
    # shape (1, H*W) for broadcasting against flattened event data.
    with h5py.File(glob(args.data_dirs+"/*"+args.data_flavour)[0], mode="r") as f:
        network.top_mask = f['mask'][0]
        network.top_mask = network.top_mask.reshape(-1, network.top_mask.shape[0]*network.top_mask.shape[1])
        network.bottom_mask = f['mask'][1]
        network.bottom_mask = network.bottom_mask.reshape(-1, network.bottom_mask.shape[0]*network.bottom_mask.shape[1])

    # Create output directory (idiomatic replacement for try/FileExistsError).
    os.makedirs(args.output_dir, exist_ok=True)

    # Save training options to file so the run is reproducible.
    with open(args.output_dir+"/"+args.model+"_config.p", "wb") as f_out_conf:
        pickle.dump(args, f_out_conf)

    # Lists to store training progress
    train_record = []
    test_record = []

    # Training loop: epochs are tracked fractionally, one iteration adds
    # 1/len(train_loader) so non-integer --epochs values are honoured.
    current_epoch = 0.
    global_iteration = -1
    network.train()
    while current_epoch < args.epochs:
        print("STARTING EPOCH {0}".format(current_epoch))
        for iteration, data in enumerate(train_loader):
            global_iteration += 1
            network.fillData(data)
            network.fillLabel(data)
            res = network.evaluate(True)
            network.backward()
            current_epoch += 1./len(train_loader)
            res.update({'epoch': current_epoch, 'iteration': global_iteration})
            train_record.append(res)
            # Report progress every 10 iterations (and on the very first).
            if global_iteration == 0 or (global_iteration+1) % 10 == 0:
                print('TRAINING', 'Iteration', global_iteration, 'Epoch', current_epoch, 'Loss', res['loss'], res['loss_breakdown'])
            # Run one validation batch every 100 iterations.
            if (global_iteration+1) % 100 == 0:
                with torch.no_grad():
                    network.eval()
                    test_data = next(iter(test_loader))
                    network.fillLabel(test_data)
                    network.fillData(test_data)
                    res = network.evaluate(False)
                    res.update({'epoch': current_epoch, 'iteration': global_iteration})
                    test_record.append(res)
                    print('VALIDATION', 'Iteration', global_iteration, 'Epoch', current_epoch, 'Loss', res['loss'], res['loss_breakdown'])
                    network.train()
            # Save network periodically
            if (global_iteration+1) % args.save_interval == 0:
                print("Saving network state")
                torch.save(network.state_dict(), args.output_dir+"/"+args.model+"_"+str(global_iteration)+".cnn")
                torch.save(network.optimizer.state_dict(), args.output_dir+"/"+args.model+"_optimizer_"+str(global_iteration)+".cnn")
            if current_epoch >= args.epochs:
                break

    # Final checkpoint and loss-history dumps.
    torch.save(network.state_dict(), args.output_dir+"/"+args.model+".cnn")
    torch.save(network.optimizer.state_dict(), args.output_dir+"/"+args.model+"_optimizer.cnn")
    with open(args.output_dir+"/"+args.model+"_train_record.p", "wb") as f:
        pickle.dump(train_record, f)
    with open(args.output_dir+"/"+args.model+"_test_record.p", "wb") as f:
        pickle.dump(test_record, f)
    print("Training done")
if __name__ == "__main__":
    # Command-line interface: optional training hyper-parameters followed by
    # the positional data location, data flavour, model name and model options.
    parser = argparse.ArgumentParser(description='Application to train Water Cherenkov generative neural networks.')

    # Training hyper-parameters
    parser.add_argument('-e', '--epochs', type=float, default=1., required=False,
                        help="Number of epochs to train for")
    parser.add_argument('-b', '--batch_size', type=int, default=200, required=False,
                        help="Batch size")
    parser.add_argument('-j', '--num_workers', type=int, default=8, required=False,
                        help="Number of CPUs for loading data")
    parser.add_argument('-t', '--train_fraction', type=float, default=0.75, required=False,
                        help="Fraction of data used for training")
    parser.add_argument('-s', '--save_interval', type=int, default=5000, required=False,
                        help="Save network state every <save_interval> iterations")
    parser.add_argument('-o', '--output_dir', type=str, default="./", required=False,
                        help="Output directory")
    parser.add_argument('-r', '--random_seed', type=int, default=None, required=False,
                        help="Random seed")

    # Checkpoint restoration
    parser.add_argument('--network_state', type=str, default=None, required=False,
                        help="Path to network state to load (for continued training)")
    parser.add_argument('--optimizer_state', type=str, default=None, required=False,
                        help="Path to optimizer state to load (for continued training)")

    # Positional arguments
    parser.add_argument('data_dirs', type=str,
                        help="Directory with training data")
    parser.add_argument('data_flavour', type=str,
                        help="Expression that matches training data file ending")
    parser.add_argument('model', type=str,
                        help="Name of model to train")
    parser.add_argument('model_arguments', type=str, nargs="*", default="",
                        help="Arguments to pass to model, in format \"name1:value1 name2:value2 ...\"")

    args = parser.parse_args()
    print(args)
    train_model(args)