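"""train_pelican_cov.py

Training/evaluation entry point for the PELICAN covariant regression model
(PELICANRegression). The script parses command-line arguments, optionally sets up
distributed data-parallel training, builds the datasets and dataloaders, initializes
the optimizer and scheduler, trains via the Trainer class, and evaluates on the test
split, including the IR/IRC-augmented checks.

Usage sketch (the exact options are defined in src.trainer.init_argparse, so the
flags shown here are placeholders, not authoritative):

    # single process
    python train_pelican_cov.py [options]

    # multi-GPU via torchrun, which sets LOCAL_RANK for each process
    torchrun --nproc_per_node=4 train_pelican_cov.py [options]
"""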
import logging
import os
import sys
import numpy
import random
from src.trainer import which
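
# Early GPU check: read the name and total memory of the selected device from nvidia-smi
# and terminate if less than 8 GB is available.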
if which('nvidia-smi') is not None:
    min_mem = 8000
    deviceid = 0
    name, mem = os.popen('"nvidia-smi" --query-gpu=gpu_name,memory.total --format=csv,nounits,noheader').read().split('\n')[deviceid].split(',')
    print(mem)
    mem = int(mem)
    if mem < min_mem:
        print(f'Less GPU memory than requested ({mem}<{min_mem}). Terminating.')
        sys.exit()
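
# torch and the project modules are imported only after the GPU memory check, so a machine
# that fails the check exits before loading them.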
import torch
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from src.models import PELICANRegression
from src.models import tests, expand_data, ir_data, irc_data
from src.trainer import Trainer
from src.trainer import init_argparse, init_file_paths, init_logger, init_cuda, logging_printout, fix_args, set_seed, get_world_size
from src.trainer import init_optimizer, init_scheduler
from src.models.metrics_cov import metrics, minibatch_metrics, minibatch_metrics_string
from src.models.metrics_cov import loss_fn_dR, loss_fn_pT, loss_fn_m, loss_fn_psi, loss_fn_inv, loss_fn_col, loss_fn_m2, loss_fn_3d, loss_fn_4d, loss_fn_E, loss_fn_col3
from src.dataloaders import initialize_datasets, collate_fn
# This makes printing tensors more readable.
torch.set_printoptions(linewidth=1000, threshold=100000, sci_mode=False)
def main():
    # Initialize arguments
    args = init_argparse()

    # Initialize file paths
    args = init_file_paths(args)

    # Fix possible inconsistencies in arguments
    # args = fix_args(args)
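
    # LOCAL_RANK is set per process by torch.distributed launchers such as torchrun;
    # -1 below marks a plain single-process run.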
    if 'LOCAL_RANK' in os.environ:
        device_id = int(os.environ["LOCAL_RANK"])
    else:
        device_id = -1
    # Initialize logger
    logger = logging.getLogger('')
    init_logger(args, device_id)

    if which('nvidia-smi') is not None:
        logger.info(f'Using {name} with {mem} MB of GPU memory (local rank {device_id})')

    # Write input parameters and paths to log
    if device_id <= 0:
        logging_printout(args)
    # Initialize device and data type
    device, dtype = init_cuda(args, device_id)

    # Fix possible inconsistencies in arguments
    args = fix_args(args)

    # Set a manual random seed for torch, cuda, numpy, and random
    args = set_seed(args, device_id)

    distributed = (get_world_size() > 1)
    if distributed:
        world_size = dist.get_world_size()
        logger.info(f'World size {world_size}')
    # Initialize dataloader
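    # If args.fix_data is set, seed torch with a fixed value before building the datasets so
    # that any randomness in the data selection is identical across runs.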
    if args.fix_data:
        torch.manual_seed(165937750084982)
    args, datasets = initialize_datasets(args, args.datadir, num_pts=None, testfile=args.testfile, RAMdataset=args.RAMdataset)
    # Construct PyTorch dataloaders from datasets
    collate = lambda data: collate_fn(data, scale=args.scale, nobj=args.nobj)
    distribute_eval = args.distribute_eval

    if distributed:
        samplers = {'train': DistributedSampler(datasets['train'], shuffle=args.shuffle),
                    'valid': DistributedSampler(datasets['valid'], shuffle=False),
                    'test': DistributedSampler(datasets['test'], shuffle=False) if distribute_eval else None}
    else:
        samplers = {split: None for split in datasets.keys()}
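
    # Under DDP the DistributedSampler shards the data and (for the training split) shuffles it,
    # so the DataLoader's own shuffle flag is only honored in the single-process case.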
    dataloaders = {split: DataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle if (split == 'train' and not distributed) else False,
                                     num_workers=args.num_workers,
                                     pin_memory=True,
                                     worker_init_fn=seed_worker,
                                     collate_fn=collate,
                                     sampler=samplers[split])
                   for split, dataset in datasets.items()}
    # Initialize model
    model = PELICANRegression(args.rank1_width_multiplier, args.num_channels_scalar, args.num_channels_m, args.num_channels_2to2, args.num_channels_out, args.num_channels_m_out,
                              num_targets=args.num_targets, stabilizer=args.stabilizer, method=args.method,
                              activate_agg=args.activate_agg, activate_lin=args.activate_lin,
                              activation=args.activation, config=args.config, config_out=args.config_out, average_nobj=args.nobj_avg,
                              factorize=args.factorize, masked=args.masked,
                              activate_agg_out=args.activate_agg_out, activate_lin_out=args.activate_lin_out, mlp_out=args.mlp_out,
                              scale=args.scale, irc_safe=args.irc_safe, dropout=args.dropout, drop_rate=args.drop_rate, drop_rate_out=args.drop_rate_out, batchnorm=args.batchnorm,
                              dataset=args.dataset, device=device, dtype=dtype)

    model.to(device)
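
    # Wrap the model in DistributedDataParallel when running multi-process so that gradients
    # are synchronized across replicas; device_ids pins this replica to its local GPU.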
    if distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])
    # Initialize the scheduler and optimizer
    if args.task.startswith('eval'):
        optimizer = scheduler = None
        restart_epochs = []
        args.summarize = False
    else:
        optimizer = init_optimizer(args, model, len(dataloaders['train']))
        scheduler, restart_epochs = init_scheduler(args, optimizer)
    # Define a loss function. This is the loss function whose gradients are actually computed.
    loss_fn = lambda predict, targets: 0.05 * loss_fn_m(predict, targets) + 0.01 * loss_fn_3d(predict, targets)
    # Other terms tried previously, kept for reference:
    #   0.01 * loss_fn_E(predict, targets) + 10 * loss_fn_psi(predict, targets)
    #   0.02 * loss_fn_E(predict, targets), 0.01 * loss_fn_pT(predict, targets), 0.03 * loss_fn_inv(predict, targets)
    # loss_fn = lambda predict, targets: 0.0005 * loss_fn_col(predict, targets) + 0.01 * (-predict[..., 0]).relu().mean() + 0.001 * loss_fn_inv(predict, targets)  # 0.1 * loss_fn_m(predict, targets)
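    # Judging by the function names, the active loss combines an invariant-mass term (loss_fn_m)
    # and a 3-momentum term (loss_fn_3d), with relative weights 0.05 and 0.01.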
    # Apply the covariance and permutation invariance tests.
    if args.test and device_id <= 0:
        with torch.autograd.set_detect_anomaly(True):
            tests(model, dataloaders['train'], args, tests=['gpu', 'irc', 'permutation'], cov=True)
    # Instantiate the training class
    trainer = Trainer(args, dataloaders, model, loss_fn, metrics,
                      minibatch_metrics, minibatch_metrics_string, optimizer, scheduler,
                      restart_epochs, device_id, device, dtype)
    if not args.task.startswith('eval'):
        # Load from checkpoint file (if one exists)
        trainer.load_checkpoint()
        # Restore random seed
        args = set_seed(args, device_id)

        # This makes the results exactly reproducible on a GPU (on CPU they're reproducible regardless) by banning certain non-deterministic operations
        if args.reproducible:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
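            # ':16:8' (or the larger ':4096:8') is the workspace setting cuBLAS requires for
            # deterministic results on CUDA 10.2 and newer.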
        # Train model.
        trainer.train()
    # Test predictions on the best model and/or the last checkpointed model.
    trainer.evaluate(splits=['test'], distributed=distributed and distribute_eval)

    if args.test:
        args.predict = False
        trainer.summarize_csv = False
        logger.info('EVALUATING BEST MODEL ON IR-SPLIT DATA (ADDED ONE 0-MOMENTUM PARTICLE)')
        trainer.evaluate(splits=['test'], final=False, ir_data=ir_data, expand_data=expand_data)
        logger.info('EVALUATING BEST MODEL ON IRC-SPLIT DATA (ADD A NEW PARTICLE SLOT AND SPLIT ONE BEAM INTO TWO EQUAL HALVES)')
        trainer.evaluate(splits=['test'], final=False, c_data=irc_data, expand_data=expand_data)

    if distributed:
        dist.destroy_process_group()
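
# Per-worker seeding for DataLoader workers: derive numpy's and random's seeds from torch's
# initial seed so that data loading in worker processes is reproducible (the standard recipe
# from the PyTorch documentation on DataLoader randomness).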
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
if __name__ == '__main__':
    main()