-
Notifications
You must be signed in to change notification settings - Fork 35
/
main.py
438 lines (380 loc) · 18.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
import logging
import os
import sys
import time
from argparse import ArgumentParser
from pathlib import Path
import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.logging import TestTubeLogger
import torch
from torch import nn, optim
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor
from torch_geometric.data import DataLoader
from mesh_utils import plot_field, is_ccw, is_cw
from su2torch import activate_su2_mpi
from data import MeshAirfoilDataset
from models import CFDGCN, MeshGCN, UCM, CFD
def parse_args():
    """Build the command-line interface and return the parsed namespace.

    On top of the parsed flags, the returned namespace is post-processed:

    * ``nodename`` is set to the host name reported by ``os.uname()``;
    * ``exp_name`` falls back to ``model`` when left empty;
    * ``val_check_interval`` falls back to ``1.0`` when not supplied;
    * ``distributed_backend`` is always ``'dp'``.
    """
    cli = ArgumentParser()
    add = cli.add_argument

    # Experiment / input-output options.
    add('--exp-name', '-e', default='',
        help='Experiment name, defaults to model name.')
    add('--su2-config', '-sc', default='coarse_train.cfg')
    add('--data-dir', '-d', default='data/128px_NACA0012_fine_3NN',
        help='Directory with dataset.')
    add('--coarse-mesh', default=None,
        help='Path to coarse mesh (required for CFD-GCN).')
    add('--version', type=int, default=None,
        help='If specified log version doesnt exist, create it.'
             ' If it exists, continue from where it stopped.')
    add('--load-model', '-lm', default='', help='Load previously trained model.')

    # Model / optimisation options.
    add('--model', '-m', default='gcn',
        help='Which model to use.')
    add('--max-epochs', '-me', type=int, default=500,
        help='Max number of epochs to train for.')
    add('--optim', default='adam', help='Optimizer.')
    add('--batch-size', '-bs', type=int, default=16)
    add('--learning-rate', '-lr', dest='lr', type=float, default=5e-5)
    add('--num-layers', '-nl', type=int, default=6)
    add('--num-end-convs', type=int, default=3)
    add('--hidden-size', '-hs', type=int, default=512)
    add('--freeze-mesh', action='store_true',
        help='Do not do any learning on the mesh.')

    # Run-mode options.
    add('--eval', action='store_true',
        help='Skips training, does only eval.')
    add('--profile', action='store_true',
        help='Run profiler.')
    add('--seed', type=int, default=0,
        help='Random seed')
    add('--gpus', type=int, default=1,
        help='Number of gpus to use, 0 for none.')
    add('--dataloader-workers', '-dw', type=int, default=4,
        help='Number of Pytorch Dataloader workers to use.')
    add('--train-val-split', '-tvs', type=float, default=0.9,
        help='Percentage of training set to use for training.')
    add('--val-check-interval', '-vci', type=int, default=None,
        help='Run validation every N batches, '
             'defaults to once every epoch.')
    add('--early-stop-patience', '-esp', type=int, default=0,
        help='Patience before early stopping. '
             'Does not early stop by default.')
    add('--train-pct', type=float, default=1.0,
        help='Run on a reduced percentage of the training set,'
             ' defaults to running with full data.')
    add('--verbose', type=int, default=1, choices=[0, 1],
        help='Verbosity level. Defaults to 1, 0 for quiet.')
    add('--debug', action='store_true',
        help='Run in debug mode. Doesnt write logs. Runs '
             'a single iteration of training and validation.')
    add('--no-log', action='store_true',
        help='Dont save any logs or checkpoints.')

    ns = cli.parse_args()
    ns.nodename = os.uname().nodename
    # exp_name is always a str, so an empty value falls back to the model name.
    ns.exp_name = ns.exp_name or ns.model
    # A float of 1.0 means "validate once per epoch" (see the --vci help text).
    ns.val_check_interval = (1.0 if ns.val_check_interval is None
                             else ns.val_check_interval)
    ns.distributed_backend = 'dp'
    return ns
class LightningWrapper(pl.LightningModule):
    """PyTorch Lightning module wrapping the mesh / CFD models from ``models``.

    Builds one of ``CFDGCN``, ``MeshGCN``, ``UCM`` or ``CFD`` from the
    hyper-parameters and trains it with an MSE loss against ``batch.y``. It also
    logs predicted/true (and, when available, simulated) field images to the
    experiment logger.

    The training split is loaded eagerly in ``__init__``. The validation and
    test loaders both read the ``'test'`` split of ``MeshAirfoilDataset``.
    """

    def __init__(self, hparams):
        """Load the training data and construct the model named by ``hparams.model``.

        Args:
            hparams: the namespace produced by ``parse_args`` (e.g. ``data_dir``,
                ``model``, ``hidden_size``, ``num_layers``, ``num_end_convs``,
                ``su2_config``, ``coarse_mesh``, ``freeze_mesh``, ``gpus``).

        Raises:
            NotImplementedError: if ``hparams.model`` is not one of
                ``'cfd_gcn'``, ``'gcn'``, ``'ucm'`` or ``'cfd'``.
        """
        super().__init__()
        self.hparams = hparams
        self.step = None  # count test step because apparently Trainer doesnt
        self.criterion = nn.MSELoss()
        self.data = MeshAirfoilDataset(hparams.data_dir, mode='train')

        # Input/output widths are read off the first training sample.
        in_channels = self.data[0].x.shape[-1]
        out_channels = self.data[0].y.shape[-1]
        hidden_channels = hparams.hidden_size
        device = 'cuda' if self.hparams.gpus > 0 else 'cpu'

        if hparams.model == 'cfd_gcn':
            model = CFDGCN(hparams.su2_config,
                           self.hparams.coarse_mesh,
                           fine_marker_dict=self.data.marker_dict,
                           hidden_channels=hidden_channels,
                           num_convs=self.hparams.num_layers,
                           num_end_convs=self.hparams.num_end_convs,
                           out_channels=out_channels,
                           process_sim=self.data.preprocess,
                           freeze_mesh=self.hparams.freeze_mesh,
                           device=device)
        elif hparams.model == 'gcn':
            model = MeshGCN(in_channels,
                            hidden_channels,
                            out_channels,
                            fine_marker_dict=self.data.marker_dict,
                            num_layers=hparams.num_layers,
                            improved=False)
        elif hparams.model == 'ucm':
            model = UCM(hparams.su2_config,
                        self.hparams.coarse_mesh,
                        fine_marker_dict=self.data.marker_dict,
                        process_sim=self.data.preprocess,
                        freeze_mesh=self.hparams.freeze_mesh,
                        device=device)
        elif hparams.model == 'cfd':
            model = CFD(hparams.su2_config,
                        self.hparams.coarse_mesh,
                        fine_marker_dict=self.data.marker_dict,
                        process_sim=self.data.preprocess,
                        freeze_mesh=self.hparams.freeze_mesh,
                        device=device)
        else:
            raise NotImplementedError(f'Unknown model: {hparams.model}')
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` (a graph batch)."""
        return self.model(x)

    def on_epoch_start(self):
        """Reset the running training-loss sum at the start of each epoch."""
        logging.info('------')
        self.sum_loss = 0.0

    def on_epoch_end(self):
        """Log the mean training loss accumulated over the epoch."""
        # max(..., 1) guards against division by zero with no training batches.
        avg_loss = self.sum_loss / max(self.trainer.num_training_batches, 1)
        train_metrics = {
            'train_loss': avg_loss,
        }
        self.trainer.log_metrics(train_metrics, {}, step=self.trainer.global_step - 1)
        # Drop an 'epoch' entry in case log_metrics added one. This must be a
        # key test: hasattr() on a dict never sees its keys.
        train_metrics.pop('epoch', None)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
        """Take an optimizer step, keeping the learnable CFD-GCN mesh valid.

        For a ``CFDGCN`` model with a trainable mesh:

        * gradients of the nodes in ``marker_inds`` are zeroed before the step;
        * after the step, nodes of any element reported by ``is_cw`` are
          reverted to their pre-step positions until no such element remains.
        """
        mesh_is_trainable = isinstance(self.model, CFDGCN) and not self.hparams.freeze_mesh
        if mesh_is_trainable:
            if self.model.nodes.grad is not None:
                # do not optimize airfoil nodes to maintain shape
                self.model.nodes.grad[self.model.marker_inds] = 0
            # save prev nodes for mesh checking below
            prev_nodes = self.model.nodes.detach().clone()
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, second_order_closure)
        if mesh_is_trainable:
            # check mesh for element flipping, if flipped dont do gradient descent update
            nodes = self.model.nodes
            elems = self.model.elems[0]
            flipped_elems = is_cw(nodes, elems).nonzero()
            # NOTE(review): this loop terminates only if reverting to prev_nodes
            # removes every flip, i.e. it assumes the pre-step mesh was valid.
            while flipped_elems.shape[0] > 0:
                flipped_inds = [elems[i] for i in flipped_elems]
                flipped_inds = torch.tensor(flipped_inds).unique()
                with torch.no_grad():
                    nodes[flipped_inds] = prev_nodes[flipped_inds]
                flipped_elems = is_cw(nodes, elems).nonzero()

    def common_step(self, batch):
        """Move ``batch`` to the device, run the model and compute the MSE loss.

        Returns:
            ``(loss, pred_fields, sub_losses)`` with
            ``sub_losses == {'batch_mse_loss': loss}``.
        """
        device = 'cuda' if self.hparams.gpus > 0 else 'cpu'
        batch = self.transfer_batch_to_device(batch, device)
        true_fields = batch.y
        pred_fields = self.forward(batch)
        mse_loss = self.criterion(pred_fields, true_fields)
        sub_losses = {'batch_mse_loss': mse_loss}
        loss = mse_loss
        return loss, pred_fields, sub_losses

    def training_step(self, batch, batch_idx):
        """Compute the training loss and accumulate it for the epoch average."""
        loss, pred, sub_losses = self.common_step(batch)
        if batch_idx + 1 == self.trainer.val_check_batch:
            # log images when doing val check
            self.log_images(batch.x[:, :2], pred, batch.y, batch.batch,
                            self.data.elems_list, 'train')
        self.sum_loss += loss.item()
        logs = {
            'batch_train_loss': loss,
        }
        logs.update(sub_losses)
        output = {
            'loss': loss,
            'progress_bar': logs,
            'log': logs,
        }
        return output

    def validation_step(self, batch, batch_idx):
        """Compute the loss on one validation batch."""
        loss, pred, sub_losses = self.common_step(batch)
        if batch_idx == 0:
            # log images only once per val epoch
            self.log_images(batch.x[:, :2], pred, batch.y, batch.batch, self.data.elems_list, 'val')
        output = {
            'batch_val_loss': loss,
        }
        output.update(sub_losses)
        return output

    def validation_end(self, outputs):
        """Average the per-batch validation losses into ``val_loss``."""
        avg_loss = torch.stack([x['batch_val_loss'] for x in outputs]).mean()
        logs = {
            'val_loss': avg_loss,
        }
        result = {
            'progress_bar': logs,
            'log': logs,
        }
        result.update(logs)
        return result

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log images for every graph in ``batch``.

        ``self.step`` serves as a running image index across test batches and
        is reset to ``None`` in ``test_end``.
        """
        loss, pred, sub_losses = self.common_step(batch)
        # batch.batch holds a 0-based graph index per node (log_images masks it
        # with ``batch == log_idx``). The graph count is therefore max + 1.
        # Iterating only up to max skipped the last graph and logged nothing
        # for single-graph batches.
        num_graphs = int(batch.batch.max()) + 1
        self.step = 0 if self.step is None else self.step
        for i in range(num_graphs):
            self.log_images(batch.x[:, :2], pred, batch.y, batch.batch,
                            self.data.elems_list, 'test', log_idx=i)
            self.step += 1
        output = {
            'batch_test_loss': loss,
        }
        output.update(sub_losses)
        return output

    def test_end(self, outputs):
        """Average the per-batch test losses into ``test_loss`` and print them."""
        avg_loss = torch.stack([x['batch_test_loss'] for x in outputs]).mean()
        self.step = None
        logs = {
            'test_loss': avg_loss,
        }
        result = {
            'progress_bar': logs,
            'log': logs,
        }
        result.update(logs)
        metrics = self.format_metrics_dict(logs)
        print(f'Test results: {metrics}', file=sys.stderr)
        return result

    def configure_optimizers(self):
        """Create the optimizer named by ``hparams.optim`` (case-insensitive).

        Returns:
            ``(optimizers, schedulers)``: a one-element optimizer list and an
            empty scheduler list.

        Raises:
            NotImplementedError: if the name is not ``adam``, ``rmsprop`` or
                ``sgd``. Without this check the method would fail later with an
                ``UnboundLocalError``.
        """
        optimizer_classes = {
            'adam': optim.Adam,
            'rmsprop': optim.RMSprop,
            'sgd': optim.SGD,
        }
        name = self.hparams.optim.lower()
        if name not in optimizer_classes:
            raise NotImplementedError(f'Unknown optimizer: {self.hparams.optim}')
        optimizers = [optimizer_classes[name](self.parameters(), lr=self.hparams.lr)]
        schedulers = []
        return optimizers, schedulers

    def train_dataloader(self):
        """Return a loader over the training split loaded in ``__init__``."""
        train_data = self.data
        train_loader = DataLoader(train_data,
                                  batch_size=self.hparams.batch_size,
                                  # dont shuffle if using reduced set
                                  shuffle=(self.hparams.train_pct == 1.0),
                                  num_workers=self.hparams.dataloader_workers)
        if self.hparams.verbose:
            logging.info(f'Train data: {len(train_data)} examples, '
                         f'{len(train_loader)} batches.')
        return train_loader

    def val_dataloader(self):
        """Return a shuffled loader over the ``'test'`` split, used for validation."""
        # use test data here to get full training curve for test set
        val_data = MeshAirfoilDataset(self.hparams.data_dir, mode='test')
        self.val_data = val_data
        val_loader = DataLoader(val_data,
                                batch_size=self.hparams.batch_size,
                                shuffle=True,
                                num_workers=self.hparams.dataloader_workers)
        if self.hparams.verbose:
            logging.info(f'Val data: {len(self.val_data)} examples, '
                         f'{len(val_loader)} batches.')
        return val_loader

    def test_dataloader(self):
        """Return an unshuffled loader over the ``'test'`` split."""
        test_data = MeshAirfoilDataset(self.hparams.data_dir, mode='test')
        test_loader = DataLoader(test_data,
                                 batch_size=self.hparams.batch_size,
                                 shuffle=False,
                                 num_workers=self.hparams.dataloader_workers)
        if self.hparams.verbose:
            logging.info(f'Test data: {len(test_data)} examples, '
                         f'{len(test_loader)} batches.')
        return test_loader

    def log_images(self, nodes, pred, true, batch, elems_list, mode, log_idx=0):
        """Plot the fields of graph ``log_idx`` and add them to the logger.

        For every output field, a grid of ``[sim,] pred, true`` images is
        written under the tag ``f'{mode}_pred_f{field}'``. The ``sim`` image is
        included only when the model has a ``sim_info`` attribute. The
        prediction and simulation plots share the colour limits of the true
        field.

        Args:
            nodes: per-node plot positions (callers pass ``batch.x[:, :2]``).
            pred, true: per-node predicted / target fields, one column per field.
            batch: per-node graph index; nodes with ``batch == log_idx`` are plotted.
            elems_list: element list handed to ``plot_field``.
            mode: tag prefix such as ``'train'``, ``'val'`` or ``'test'``.
            log_idx: which graph of the batch to plot.
        """
        # With --no-log the logger is False; the short-circuit avoids touching it.
        if self.hparams.no_log or self.logger.debug:
            return
        inds = batch == log_idx
        nodes = nodes[inds]
        pred = pred[inds]
        true = true[inds]
        exp = self.logger.experiment
        # During testing self.step is a running index; otherwise use the global step.
        step = self.trainer.global_step if self.step is None else self.step
        for field in range(pred.shape[1]):
            true_img = plot_field(nodes, elems_list, true[:, field],
                                  title='true')
            true_img = ToTensor()(true_img)
            min_max = (true[:, field].min().item(), true[:, field].max().item())
            pred_img = plot_field(nodes, elems_list, pred[:, field],
                                  title='pred', clim=min_max)
            pred_img = ToTensor()(pred_img)
            imgs = [pred_img, true_img]
            if hasattr(self.model, 'sim_info'):
                sim = self.model.sim_info
                sim_inds = sim['batch'] == log_idx
                sim_nodes = sim['nodes'][sim_inds]
                sim_info = sim['output'][sim_inds]
                sim_elems = sim['elems'][log_idx]
                # Map selected simulation nodes to 0..k-1 (others to -1) so the
                # element list can be re-indexed for this graph alone.
                mesh_inds = torch.full_like(sim['batch'], fill_value=-1,
                                            dtype=torch.long, device='cpu')
                mesh_inds[sim_inds] = torch.arange(sim_nodes.shape[0])
                sim_elems_list = self.model.contiguous_elems_list(sim_elems, mesh_inds)
                sim_img = plot_field(sim_nodes, sim_elems_list, sim_info[:, field],
                                     title='sim', clim=min_max)
                sim_img = ToTensor()(sim_img)
                imgs = [sim_img] + imgs
            grid = make_grid(torch.stack(imgs), padding=0)
            img_name = f'{mode}_pred_f{field}'
            exp.add_image(img_name, grid, global_step=step)

    def transfer_batch_to_device(self, batch, device):
        """Move every attribute of ``batch`` that has a ``.to`` method onto ``device``."""
        # Unpacks batch as (key, value) pairs and writes moved values back by key.
        for k, v in batch:
            if hasattr(v, 'to'):
                batch[k] = v.to(device)
        return batch

    @staticmethod
    def format_metrics_dict(metrics_dict, exclude='batch'):
        """Format metrics as ``'name: value'`` pairs (3 significant digits).

        Keys containing ``exclude`` are skipped.
        """
        return ', '.join(f'{k}: {v:.3}'
                         for k, v in metrics_dict.items()
                         if exclude not in k)

    @staticmethod
    def get_cross_prods(meshes, store_elems):
        """Return ``is_ccw(mesh[e, :2], ret_val=True)`` for every element of every mesh."""
        cross_prods = [is_ccw(mesh[e, :2], ret_val=True)
                       for mesh, elems in zip(meshes, store_elems) for e in elems]
        return cross_prods
if __name__ == '__main__':
    # Script entry point. Parse the CLI, build or restore the model, configure
    # logging, checkpointing and early stopping, then train and optionally test.
    activate_su2_mpi(remove_temp_files=True)
    args = parse_args()
    print(args, file=sys.stderr)
    torch.manual_seed(args.seed)

    if args.load_model:
        # Hyper-parameters are restored from the meta_tags.csv stored in the
        # checkpoint's directory.
        pl_model = LightningWrapper.load_from_checkpoint(
            args.load_model,
            tags_csv=Path(args.load_model).parent / 'meta_tags.csv',
        )
    else:
        pl_model = LightningWrapper(args)

    if args.no_log:
        logger = False
    else:
        logger = TestTubeLogger(save_dir='logs',
                                name=args.exp_name,
                                debug=args.debug,
                                version=args.version,
                                create_git_tag=False)
        if not args.debug:
            # Plot the three loss curves on one shared chart.
            logger.experiment.add_custom_scalars_multilinechart(
                ['train_loss', 'val_loss', 'test_loss'],
                title='loss',
            )

    checkpoint_callback = None
    if not (args.debug or args.no_log):
        run_dir = logger.experiment.get_data_path(logger.name, logger.version)
        checkpoint_callback = ModelCheckpoint(filepath=os.path.join(run_dir, 'checkpoints'),
                                              monitor='val_loss',
                                              mode='auto',
                                              period=1,
                                              save_top_k=10,
                                              save_weights_only=False,
                                              verbose=args.verbose)

    # A patience of 0 (the default) disables early stopping.
    early_stop_callback = (
        EarlyStopping(monitor='val_loss',
                      mode='auto',
                      min_delta=0.0,
                      patience=args.early_stop_patience,
                      verbose=args.verbose)
        if args.early_stop_patience else False
    )

    trainer = pl.Trainer(logger=logger,
                         weights_summary='full' if args.verbose else None,
                         checkpoint_callback=checkpoint_callback,
                         early_stop_callback=early_stop_callback,
                         gpus=args.gpus,
                         distributed_backend=args.distributed_backend,
                         # --eval skips training entirely by allowing 0 epochs.
                         max_epochs=0 if args.eval else args.max_epochs,
                         val_check_interval=args.val_check_interval,
                         train_percent_check=args.train_pct,
                         num_sanity_val_steps=1,
                         profiler=args.profile,
                         fast_dev_run=args.debug)
    trainer.fit(pl_model)
    if args.eval:
        trainer.test()