###################################### Imports #################################################
import argparse
import os
import time
import torch
torch.set_default_dtype(torch.float32)
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append('/hpcfs/bes/mlgpu/xingty/myliu/Other/nflows/')
from nflows import transforms, distributions, flows
from data import get_dataloader
from data import save_samples_to_file, save_calib_to_file, save_dedx_to_file
#from model_calib import Nett, Net_val
##################################### Parser setup #############################################
parser = argparse.ArgumentParser()
# usage modes
parser.add_argument('--training', action='store_true', help='train the flow')
parser.add_argument('--generate_to_file', action='store_true', help='generate from a trained flow and save to file')
parser.add_argument('--save_pt', action='store_true', help='trace the trained flow with TorchScript and save it')
parser.add_argument('--check_pt', action='store_true', help='load a traced TorchScript flow and sanity-check it')
parser.add_argument('--restore', action='store_true', help='restore and train a flow')
parser.add_argument('--no_cuda', action='store_true', help='Do not use cuda.')
parser.add_argument('--which_cuda', default=0, type=int, help='Which cuda device to use')
parser.add_argument('--output_dir', default='./results', help='Where to store the output')
parser.add_argument('--gen_dir', default='', help='Where to store the generated file')
parser.add_argument('--output_file', default='default.hdf5', help='Name of the generated HDF5 file')
parser.add_argument('--results_file', default='results.txt', help='Filename where to store settings and test results.')
parser.add_argument('--restore_file', type=str, default=None, help='Model file to restore.')
parser.add_argument('--data_dir', default='', help='Where to find the training dataset')
parser.add_argument('--use_test_dataloader', action='store_true', help='Generate from the test dataloader instead of the train one')
parser.add_argument('--particle_type', '-p', help='Which particle: "e+", "e-", "pi+", "pi-", "k+", "k-", "p+", or "p-"')
parser.add_argument('--layerID', '-layer', help='layer ID')
parser.add_argument('--num_feature', default=1, type=int, help='How many features are trained')
parser.add_argument('--num_block', default=2, type=int, help='Number of layers per flow block')
parser.add_argument('--hidden_features', default=64, type=int, help='Hidden layer size of the flow blocks')
parser.add_argument('--num_epochs', default=100, type=int, help='How many epochs are trained')
parser.add_argument('--gen_events', default=100, type=int, help='How many events are generated')
# MAF parameters
parser.add_argument('--n_blocks', type=str, default='8',
                    help='Total number of blocks to stack in a model (MADE in MAF).')
parser.add_argument('--hidden_size_multiplier', type=int, default=None,
                    help='Hidden layer size for each MADE block in an MAF'
                         ' is given by the dimension times this factor.')
parser.add_argument('--n_hidden', type=int, default=1,
                    help='Number of hidden layers in each MADE.')
parser.add_argument('--activation_fn', type=str, default='relu',
                    help='Which activation function of torch.nn.functional to use in the MADEs.')
parser.add_argument('--batch_norm', action='store_true', default=False,
                    help='Use batch normalization')
parser.add_argument('--n_bins', type=int, default=8,
                    help='Number of bins if piecewise transforms are used')
parser.add_argument('--use_residual', action='store_true', default=False,
                    help='Use residual layers in the NNs')
parser.add_argument('--dropout_probability', '-d', type=float, default=0.05,
                    help='dropout probability')
parser.add_argument('--tail_bound', type=float, default=14., help='Domain of the RQS')
parser.add_argument('--cond_base', action='store_true', default=False,
                    help='Use Gaussians conditioned on energy as base distribution.')
parser.add_argument('--init_id', action='store_true',
                    help='Initialize Flow to be identity transform')
# training params
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--num_try', type=int, default=1)
parser.add_argument('--n_epochs', type=int, default=100)
parser.add_argument('--log_interval', type=int, default=100,
                    help='How often to show loss statistics and save samples.')
parser.add_argument('--workpath', type=str, default='.', help='work path')
parser.add_argument('--save_model_name', type=str, default='.', help='Filename under output_dir for the best model checkpoint')
parser.add_argument('--schedulerType', type=int, default=0, help='0: MultiStepLR, 1: warm-up then scale on plateau')
parser.add_argument('--input_lr', type=float, default=0.0, help='Override the initial learning rate (0 keeps the default)')
parser.add_argument('--w_decay', type=float, default=0, help='weight decay')
parser.add_argument('--trunc', type=float, default=0, help='truncation')
parser.add_argument('--gen_batch', type=int, default=1000)
parser.add_argument('--pt_file_path', type=str, default='.', help='Where to save the traced TorchScript model')
parser.add_argument('--check_pt_file', type=str, default='.', help='Traced TorchScript model file to check')
####################################### helper functions #######################################
class LRWarmUPSF(object):
    """ Learning-rate schedule: linear warm-up to target_lr over warmup_iteration steps,
        then multiply the rate by sf whenever the loss improves by less than threshold. """
    def __init__(self, optimizer, warmup_iteration, target_lr, threshold, sf):
        self.optimizer = optimizer
        self.warmup_iteration = warmup_iteration
        self.target_lr = target_lr
        self.previous_loss = 9999
        self.change_threshold = threshold
        self.sf = sf

    def warmup_learning_rate(self, cur_iteration):
        warmup_lr = self.target_lr * float(cur_iteration) / float(self.warmup_iteration)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = warmup_lr

    def sf_learning_rate(self, loss):
        # scale the learning rate down unless the loss improved by more than the threshold
        if (self.previous_loss - loss) > self.change_threshold:
            pass
        else:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * self.sf
        self.previous_loss = loss

    def step(self, cur_iteration, loss):
        if cur_iteration <= self.warmup_iteration:
            self.warmup_learning_rate(cur_iteration)
        else:
            self.sf_learning_rate(loss)
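# A minimal usage sketch of LRWarmUPSF (the optimizer and the loss values below are
# illustrative, not from this script): call step() once per epoch with the current
# iteration and the training loss, as train_flow() does further down.
#
#   model_ = torch.nn.Linear(4, 1)
#   optim_ = torch.optim.Adam(model_.parameters(), lr=0.0)
#   sched_ = LRWarmUPSF(optim_, warmup_iteration=5, target_lr=1e-3, threshold=1e-4, sf=0.7)
#   for epoch_, loss_ in enumerate([3.0, 2.5, 2.4, 2.399, 2.3989, 2.3988], start=1):
#       sched_.step(epoch_, loss_)  # linear warm-up for 5 steps, then scale on plateau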
# used in the transformation between energy and logit space
# (should match the ALPHA in data.py)
ALPHA = 1e-6

def logit(x):
    """ returns the logit of the input """
    return torch.log(x / (1.0 - x))

def logit_trafo(x):
    """ implements the logit trafo of the MAF paper https://arxiv.org/pdf/1705.07057.pdf """
    local_x = ALPHA + (1. - 2.*ALPHA) * x
    return logit(local_x)

def inverse_logit(x, clamp_low=0., clamp_high=1.):
    """ inverts logit_trafo(), clips the result if needed """
    return ((torch.sigmoid(x) - ALPHA) / (1. - 2.*ALPHA)).clamp_(clamp_low, clamp_high)
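# A quick round-trip check (illustrative, not part of the pipeline): logit_trafo()
# maps (0, 1) to the real line and inverse_logit() maps back, up to float precision.
#
#   >>> x = torch.tensor([0.25, 0.50, 0.75])
#   >>> torch.allclose(inverse_logit(logit_trafo(x)), x, atol=1e-5)
#   True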
def one_hot(values, num_bins):
    """ one-hot encoding of values into num_bins """
    # values are energies in [0, 1], need to be converted to integers in [0, num_bins-1]
    values *= num_bins
    values = values.type(torch.long)
    ret = F.one_hot(values, num_bins)
    return ret.squeeze().double()

def one_blob(values, num_bins):
    """ one-blob encoding of values into num_bins, cf. sec. 4.3 of 1808.03856 """
    # torch.tile() not yet in stable release, use numpy instead
    values = values.cpu().numpy()[..., np.newaxis]
    y = np.tile(((0.5/num_bins) + np.arange(0., 1., step=1./num_bins)), values.shape)
    res = np.exp(((-num_bins*num_bins)/2.) * (y-values)**2)
    res = np.reshape(res, (-1, values.shape[-1]*num_bins))
    return torch.tensor(res)
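# Illustrative shapes (not from the source): one_blob() places a Gaussian bump over
# the bins nearest each value, so an (N, 1) input becomes an (N, num_bins) tensor.
#
#   >>> enc = one_blob(torch.tensor([[0.5]]), num_bins=4)
#   >>> enc.shape
#   torch.Size([1, 4])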
def remove_nans(tensor):
    """ removes elements in the given batch that contain NaNs,
        returns the new tensor and the number of removed elements """
    tensor_flat = tensor.flatten(start_dim=1)
    # NaN != NaN, so this keeps only the rows without NaNs
    good_entries = torch.all(tensor_flat == tensor_flat, dim=1)
    res_flat = tensor_flat[good_entries]
    tensor_shape = list(tensor.size())
    tensor_shape[0] = -1
    res = res_flat.reshape(tensor_shape)
    return res, len(tensor) - len(res)
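# Illustrative example (not from the source): a batch with one NaN row is filtered
# out and the number of dropped rows is returned alongside the cleaned tensor.
#
#   >>> t = torch.tensor([[1., 2.], [float('nan'), 3.]])
#   >>> clean, n_dropped = remove_nans(t)
#   >>> clean.shape, n_dropped
#   (torch.Size([1, 2]), 1)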
def generate_to_file(args, num_events, sim_model, data_loader):
    filename = os.path.join(args.gen_dir, args.output_file)
    if not os.path.isdir(args.gen_dir):
        os.makedirs(args.gen_dir)
    generating(args, args.num_try, num_events, sim_model, data_loader, filename)

def save_model(model, optimizer, args):
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               os.path.join(args.output_dir, args.save_model_name))

def load_model(model, optimizer, args):
    checkpoint = torch.load(os.path.join(args.output_dir, args.restore_file))
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.to(args.device)
    model.eval()
################################# auxiliary NNs and classes #######################################
class ContextEmbedder(torch.nn.Module):
    """ Small NN to be used for the embedding of the conditionals """
    def __init__(self, input_size, output_size):
        """ input_size: length of context vector
            output_size: length of context vector to be fed to the flow
        """
        super(ContextEmbedder, self).__init__()
        self.layer1 = torch.nn.Linear(input_size, (input_size+output_size)//2)
        self.layer2 = torch.nn.Linear((input_size+output_size)//2, (input_size+output_size)//2)
        self.output = torch.nn.Linear((input_size+output_size)//2, output_size)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        out = self.output(x)
        return out
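# Illustrative usage (the sizes here are hypothetical, not from this script): the
# embedder maps a raw context vector to the size the flow expects; in nflows it
# could be passed as flows.Flow(..., embedding_net=ContextEmbedder(3, 8)).
#
#   embed_ = ContextEmbedder(input_size=3, output_size=8)
#   ctx_ = torch.randn(16, 3)              # a batch of 16 raw conditionals
#   assert embed_(ctx_).shape == (16, 8)   # embedded conditionals for the flow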
@torch.no_grad()
def generating(args, num_pts, N_evt, sim_model, data_loader, filename):
    """ samples the flow conditioned on the labels of N_evt events from the
        dataloader and writes the results to batched HDF5 files """
    y_label = None
    dedx_ori = None
    isFirst = True
    for i, data in enumerate(data_loader):
        if isFirst:
            y_label = data['label']
            dedx_ori = data['dedx']
            isFirst = False
        else:
            dedx_ori = torch.cat((dedx_ori, data['dedx']), 0)
            y_label = torch.cat((y_label, data['label']), 0)
        if dedx_ori.size()[0] >= N_evt:
            break
    dedx_ori = dedx_ori[0:N_evt]
    y_label = y_label[0:N_evt]
    for i in range(0, y_label.size()[0], args.gen_batch):
        tmp_label = y_label[i:i+args.gen_batch, :]
        tmp_dedx_ori = dedx_ori[i:i+args.gen_batch, :]
        # sample either from the flow directly or from the traced TorchScript module
        if not args.check_pt:
            dedx_dist_unit = sample_flow(sim_model, num_pts, args, tmp_label).to('cpu')
        else:
            dedx_dist_unit = jit_sample_flow(sim_model, args, tmp_label)
        # dedx_scale is a module-level global set in __main__
        dedx_dist = dedx_scale * dedx_dist_unit
        tmp_filename = filename.replace('.hdf5', '_%d.hdf5' % (i // args.gen_batch))
        dedx_dist = torch.squeeze(dedx_dist, 2)
        save_samples_to_file(dedx_sim=dedx_dist, dedx_ori=tmp_dedx_ori,
                             real_label=tmp_label, filename=tmp_filename)
def train_flow(sim_model, train_data, test_data, optim, args):
    """ trains the flow that learns the distributions """
    best_eval_logprob_rec = float('-inf')
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        optim, milestones=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], gamma=0.5)
    if args.schedulerType == 1:
        peak_lr = 0.001
        warmup_iter = 5 if not args.restore else -1
        eps = 1e-4
        sf = 0.7
        lr_schedule = LRWarmUPSF(optimizer=optim, warmup_iteration=warmup_iter,
                                 target_lr=peak_lr, threshold=eps, sf=sf)
        lr_schedule.step(1, 999)
    if args.input_lr != 0:
        for param_group in optim.param_groups:
            param_group['lr'] = args.input_lr
    for epoch in range(args.num_epochs):
        sim_model.train()
        loglike_train = []
        tmp_lr = 0
        for param_group in optim.param_groups:
            tmp_lr = param_group['lr']
        for i, data in enumerate(train_data):
            x0 = data['dedx'].float()
            y = data['label'].to(args.device)
            # scale dE/dx to the unit interval, then map to logit space
            x = (x0/dedx_scale).clamp_(0., 1.).to(args.device)
            x = logit_trafo(x)
            loss = - sim_model.log_prob(x, y).mean(0)
            optim.zero_grad()
            loss.backward()
            optim.step()
            loglike_train.append(loss.item())
        logprob_mean_train = np.mean(loglike_train)
        logprob_std_train = np.std(loglike_train)
        output = 'Flow: Training (epoch {}) -- '.format(epoch+1) + '-logp = {:.3f} +/- {:.3f}, lr is {:f}'
        print(output.format(logprob_mean_train, logprob_std_train, tmp_lr), file=open(args.results_file, 'a'))
        logprob_mean = float('-inf')
        with torch.no_grad():
            sim_model.eval()
            loglike = []
            for data in test_data:
                x0 = data['dedx'].float()
                y = data['label'].to(args.device)
                x = (x0/dedx_scale).clamp_(0., 1.).to(args.device)
                x = logit_trafo(x)
                loglike.append(sim_model.log_prob(x, y))
            logprobs = torch.cat(loglike, dim=0).to(args.device)
            logprob_mean = logprobs.mean(0)
            logprob_std = logprobs.var(0).sqrt()
            output = 'Flow: Evaluate (epoch {}) -- '.format(epoch+1) + 'logp = {:.3f} +/- {:.3f}, lr is {:f}'
            print(output.format(logprob_mean, logprob_std, tmp_lr), file=open(args.results_file, 'a'))
        if args.schedulerType == 1:
            lr_schedule.step(epoch+2, logprob_mean_train)
        else:
            lr_schedule.step()
        # keep the checkpoint with the best evaluation log-likelihood
        if logprob_mean >= best_eval_logprob_rec:
            best_eval_logprob_rec = logprob_mean
            save_model(sim_model, optim, args)
@torch.no_grad()
def sample_flow(sim_model, num_pts, args, label_data=None):
    """ draws num_pts samples per label from the flow and maps them back from logit space """
    sim_model.eval()
    samples = sim_model.sample(num_pts, label_data.to(args.device))
    samples = inverse_logit(samples)
    return samples
@torch.no_grad()
def save_flow(sim_model, args):
    """ traces the flow with TorchScript on CPU in single precision and saves it """
    sim_model.eval()
    tmp_device = torch.device("cpu")
    sim_model.to(tmp_device)
    sim_model.float()  # trace in single precision
    # example input: the conditional labels concatenated with base-distribution noise
    # (cf. jit_sample_flow below)
    example_input = torch.tensor([[0.0784, 0.4385, 0.8700, 0.3949]])
    module = torch.jit.trace_module(sim_model, {'forward': example_input.float()})
    module.float()
    module.save(args.pt_file_path)
    print('saved %s' % args.pt_file_path)
@torch.no_grad()
def jit_sample_flow(sim_model, args, label_data):
    """ samples from a traced TorchScript flow: forward() takes the labels
        concatenated with standard-normal noise and returns one sample per label """
    sim_model.eval()
    tmp_device = torch.device("cpu")
    label_data = label_data.to(tmp_device)
    noise = torch.randn(label_data.size()[0], 1)
    example_input = torch.cat((label_data, noise), 1)
    samples = sim_model.forward(example_input)
    samples = inverse_logit(samples)
    return samples
####################################################################################################
####################################### running the code ##########################################
####################################################################################################
if __name__ == '__main__':
    args = parser.parse_args()
    cond_label_size = 3  # mom, theta, nhit
    # check if output_dir exists and 'move' the results file there
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
    args.results_file = os.path.join(args.output_dir, args.results_file)
    print(args, file=open(args.results_file, 'a'))
    # setup device
    args.device = torch.device('cuda:'+str(args.which_cuda)
                               if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print("Using {}".format(args.device))
    print("Using {}".format(args.device), file=open(args.results_file, 'a'))
    # get dataloaders
    train_dataloader, test_dataloader = get_dataloader(args.particle_type,
                                                       args.layerID,
                                                       args.data_dir,
                                                       device=args.device,
                                                       batch_size=args.batch_size)
    # per-particle dE/dx normalization: scales dE/dx into [0, 1] before the logit trafo
    if args.particle_type in ('p+', 'p-'):
        dedx_scale = 22000  # for betagamma = 0.1
    elif args.particle_type in ('pi+', 'pi-'):
        dedx_scale = 2000
    elif args.particle_type in ('k+', 'k-'):
        dedx_scale = 5000
    elif args.particle_type in ('e+', 'e-'):
        dedx_scale = 1800
    else:
        raise ValueError("No dedx_scale defined for particle type %s" % args.particle_type)
    flow_params = {'num_blocks': args.num_block,  # num of layers per block, default 2
                   'features': args.num_feature,
                   'context_features': 2,  # 1,
                   'hidden_features': args.hidden_features,  # default is 64
                   'use_residual_blocks': False,
                   'use_batch_norm': False,
                   'dropout_probability': 0.,
                   'activation': getattr(F, args.activation_fn),
                   'random_mask': False,
                   'num_bins': 8,
                   'tails': 'linear',
                   'tail_bound': 14,
                   'min_bin_width': 1e-6,
                   'min_bin_height': 1e-6,
                   'min_derivative': 1e-6}
    flow_blocks = []
    for _ in range(6):
        flow_blocks.append(
            transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                **flow_params))
        flow_blocks.append(transforms.RandomPermutation(args.num_feature))
    flow_transform = transforms.CompositeTransform(flow_blocks)
    # _sample not implemented:
    flow_base_distribution = distributions.StandardNormal(shape=[args.num_feature])
    flow = flows.Flow(transform=flow_transform, distribution=flow_base_distribution)
    sim_model = flow.to(args.device)
    sim_optimizer = torch.optim.Adam(sim_model.parameters(), lr=0.001)
    print(sim_model)
    print(sim_model, file=open(args.results_file, 'a'))
    total_parameters = sum(p.numel() for p in sim_model.parameters() if p.requires_grad)
    print("setup has {} parameters".format(int(total_parameters)))
    print("args.training = ", args.training)
    print("args.generate_to_file = ", args.generate_to_file)
    print("setup has {} parameters".format(int(total_parameters)),
          file=open(args.results_file, 'a'))
    if args.training:
        print("do training:", file=open(args.results_file, 'a'))
        if args.restore:
            print("restoring from %s" % args.restore_file, file=open(args.results_file, 'a'))
            load_model(sim_model, sim_optimizer, args)
        train_flow(sim_model, train_dataloader, test_dataloader, sim_optimizer, args)
    if args.generate_to_file:
        print("do test: ", file=open(args.results_file, 'a'))
        load_model(sim_model, sim_optimizer, args)
        if not args.use_test_dataloader:
            print("generating from the train dataloader")
            generate_to_file(args, args.gen_events, sim_model=sim_model, data_loader=train_dataloader)
        else:
            print("generating from the test dataloader")
            generate_to_file(args, args.gen_events, sim_model=sim_model, data_loader=test_dataloader)
    if args.save_pt:
        load_model(sim_model, sim_optimizer, args)
        save_flow(sim_model, args)
    if args.check_pt:
        m_net = torch.jit.load(args.check_pt_file)
        example_input = torch.tensor([[0.0784, 0.4385, 0.8700]])
        output = m_net.forward(example_input.float())
        print('example_input=', example_input, ', output=', output)
        #generate_to_file(args, num_events=1000000, sim_model=m_net, data_loader=train_dataloader)
        #generate_to_file(args, num_events=10, sim_model=m_net, data_loader=train_dataloader)
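# Example invocations (the paths and option values below are hypothetical, shown
# only to illustrate how the modes above combine; the flags themselves are the
# ones defined in the parser):
#
#   # train a flow for positrons on layer 0:
#   python dedx_flow.py --training -p e+ -layer 0 --data_dir ./data \
#          --output_dir ./results --save_model_name flow_eplus.pt
#
#   # generate 100000 events from the trained flow using the test dataloader:
#   python dedx_flow.py --generate_to_file --use_test_dataloader -p e+ -layer 0 \
#          --data_dir ./data --output_dir ./results --restore_file flow_eplus.pt \
#          --gen_dir ./gen --output_file eplus.hdf5 --gen_events 100000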