-
Notifications
You must be signed in to change notification settings - Fork 181
/
train.py
483 lines (373 loc) · 21.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
import os
import shutil
import argparse
import time
import json
from datetime import datetime
from collections import defaultdict
from itertools import islice
import pickle
import copy
import numpy as np
import cv2
import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
from tensorboardX import SummaryWriter
from mvn.models.triangulation import RANSACTriangulationNet, AlgebraicTriangulationNet, VolumetricTriangulationNet
from mvn.models.loss import KeypointsMSELoss, KeypointsMSESmoothLoss, KeypointsMAELoss, KeypointsL2Loss, VolumetricCELoss
from mvn.utils import img, multiview, op, vis, misc, cfg
from mvn.datasets import human36m
from mvn.datasets import utils as dataset_utils
def parse_args():
    """Parse the command line of the training / evaluation script.

    Returns:
        argparse.Namespace with ``config``, ``eval``, ``eval_dataset``,
        ``local_rank``, ``seed`` and ``logdir`` attributes.
    """
    # (flag, keyword arguments for add_argument), in declaration order.
    argument_specs = (
        ("--config", {"type": str, "required": True, "help": "Path, where config file is stored"}),
        ("--eval", {"action": "store_true", "help": "If set, then only evaluation will be done"}),
        ("--eval_dataset", {"type": str, "default": "val", "help": "Dataset split on which evaluate. Can be 'train' and 'val'"}),
        ("--local_rank", {"type": int, "help": "Local rank of the process on the node"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed for reproducibility"}),
        ("--logdir", {"type": str, "default": "/Vol1/dbstore/datasets/k.iskakov/logs/multi-view-net-repr", "help": "Path, where logs will be stored"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in argument_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def setup_human36m_dataloaders(config, is_train, distributed_train):
    """Build the Human3.6M train/val dataloaders described by ``config``.

    Args:
        config: experiment config. Reads ``config.dataset.train`` /
            ``config.dataset.val``, ``config.opt``, ``config.kind`` and the
            optional ``config.image_shape`` (default ``(256, 256)``).
        is_train: when False the train dataset/dataloader is not built and
            ``None`` is returned in its place.
        distributed_train: wrap the train dataset in a ``DistributedSampler``.

    Returns:
        ``(train_dataloader, val_dataloader, train_sampler)``.
        ``train_dataloader`` and ``train_sampler`` may be ``None``.
    """
    image_shape = getattr(config, "image_shape", (256, 256))

    train_dataloader = None
    # Bound up front: it is part of the return value even when is_train is False
    # (previously this raised UnboundLocalError in that case).
    train_sampler = None

    if is_train:
        # train
        train_cfg = config.dataset.train
        train_dataset = human36m.Human36MMultiViewDataset(
            h36m_root=train_cfg.h36m_root,
            pred_results_path=getattr(train_cfg, "pred_results_path", None),
            train=True,
            test=False,
            image_shape=image_shape,
            labels_path=train_cfg.labels_path,
            with_damaged_actions=train_cfg.with_damaged_actions,
            scale_bbox=train_cfg.scale_bbox,
            kind=config.kind,
            undistort_images=train_cfg.undistort_images,
            ignore_cameras=getattr(train_cfg, "ignore_cameras", []),
            crop=getattr(train_cfg, "crop", True),
        )

        if distributed_train:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=config.opt.batch_size,
            # DataLoader does not accept shuffle=True together with an explicit sampler.
            shuffle=train_cfg.shuffle and (train_sampler is None),
            sampler=train_sampler,
            collate_fn=dataset_utils.make_collate_fn(randomize_n_views=train_cfg.randomize_n_views,
                                                     min_n_views=train_cfg.min_n_views,
                                                     max_n_views=train_cfg.max_n_views),
            num_workers=train_cfg.num_workers,
            worker_init_fn=dataset_utils.worker_init_fn,
            pin_memory=True
        )

    # val
    val_cfg = config.dataset.val
    val_dataset = human36m.Human36MMultiViewDataset(
        h36m_root=val_cfg.h36m_root,
        pred_results_path=getattr(val_cfg, "pred_results_path", None),
        train=False,
        test=True,
        image_shape=image_shape,
        labels_path=val_cfg.labels_path,
        with_damaged_actions=val_cfg.with_damaged_actions,
        retain_every_n_frames_in_test=val_cfg.retain_every_n_frames_in_test,
        scale_bbox=val_cfg.scale_bbox,
        kind=config.kind,
        undistort_images=val_cfg.undistort_images,
        ignore_cameras=getattr(val_cfg, "ignore_cameras", []),
        crop=getattr(val_cfg, "crop", True),
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=getattr(config.opt, "val_batch_size", config.opt.batch_size),
        shuffle=val_cfg.shuffle,
        collate_fn=dataset_utils.make_collate_fn(randomize_n_views=val_cfg.randomize_n_views,
                                                 min_n_views=val_cfg.min_n_views,
                                                 max_n_views=val_cfg.max_n_views),
        num_workers=val_cfg.num_workers,
        worker_init_fn=dataset_utils.worker_init_fn,
        pin_memory=True
    )

    return train_dataloader, val_dataloader, train_sampler
def setup_dataloaders(config, is_train=True, distributed_train=False):
    """Dispatch to the dataloader factory for ``config.dataset.kind``.

    Returns:
        ``(train_dataloader, val_dataloader, train_sampler)``.

    Raises:
        NotImplementedError: for any dataset kind other than ``'human36m'``.
    """
    if config.dataset.kind != 'human36m':
        raise NotImplementedError("Unknown dataset: {}".format(config.dataset.kind))

    loaders = setup_human36m_dataloaders(config, is_train, distributed_train)
    train_loader, val_loader, sampler = loaders
    return train_loader, val_loader, sampler
def setup_experiment(config, model_name, is_train=True, logdir=None, config_path=None):
    """Create a timestamped experiment directory and a TensorBoard writer for it.

    The directory is ``<logdir>/[eval_][<config.title>_]<model_name>@<timestamp>``.
    It gets a ``checkpoints/`` sub-directory, a copy of the config file named
    ``config.yaml``, and TensorBoard event files under ``tb/``.

    Args:
        config: experiment config; a truthy ``config.title`` prefixes the name.
        model_name: model class name appended to the experiment title.
        is_train: when False the experiment name is prefixed with ``eval_``.
        logdir: root directory for experiments. Defaults to the module-level
            ``args.logdir`` (bound in ``__main__``) for backward compatibility.
        config_path: config file to copy. Defaults to the module-level
            ``args.config``.

    Returns:
        ``(experiment_dir, writer)``.
    """
    # Only fall back to the script-level ``args`` global when paths are not given.
    if logdir is None:
        logdir = args.logdir
    if config_path is None:
        config_path = args.config

    prefix = "" if is_train else "eval_"
    experiment_title = config.title + "_" + model_name if config.title else model_name
    experiment_title = prefix + experiment_title

    experiment_name = '{}@{}'.format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    print("Experiment name: {}".format(experiment_name))

    experiment_dir = os.path.join(logdir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    checkpoints_dir = os.path.join(experiment_dir, "checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)

    shutil.copy(config_path, os.path.join(experiment_dir, "config.yaml"))

    # tensorboard
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

    # dump config to tensorboard. The signature is add_text(tag, text, global_step);
    # the arguments were previously swapped, using the whole config dump as the tag.
    writer.add_text("config", misc.config_to_str(config), 0)

    return experiment_dir, writer
def _relative_to_joint(keypoints, base_joint):
    """Return a copy of ``keypoints`` with ``base_joint`` subtracted from every other joint.

    ``keypoints`` is indexed as ``[batch, joint, ...]``. The base joint's own
    coordinates are left unchanged.
    """
    joint_mask = torch.arange(keypoints.shape[1]) != base_joint
    transformed = keypoints.clone()
    transformed[:, joint_mask] -= transformed[:, base_joint:base_joint + 1]
    return transformed


def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_total=0, is_train=True, caption='', master=False, experiment_dir=None, writer=None):
    """Run one training or evaluation pass over ``dataloader``.

    In training mode each batch is back-propagated and ``opt`` is stepped.
    In evaluation mode predictions are collected. On the master process they
    are then scored with ``dataloader.dataset.evaluate`` and dumped to
    ``<experiment_dir>/checkpoints/<epoch:04>/`` as ``results.pkl`` and
    ``metric.json``.

    Args:
        model: network called as ``model(images, proj_matricies, batch)``. The
            output tuple layout depends on ``config.model.name``.
        criterion: keypoint loss called as ``criterion(pred, gt, validity)``.
        opt: optimizer; only used when ``is_train``.
        config: experiment config.
        dataloader: batches are converted with ``dataset_utils.prepare_batch``.
        device: device passed to ``prepare_batch``.
        epoch: epoch index for per-epoch logging and result directory names.
        n_iters_total: global iteration counter to continue from.
        is_train: train (True) or evaluate (False).
        caption: unused; kept for interface compatibility.
        master: only the master process writes TensorBoard logs and result files.
        experiment_dir: experiment root; required on master when evaluating.
        writer: TensorBoard writer; required on master.

    Returns:
        ``n_iters_total`` advanced by the number of processed (non-None) batches.

    Raises:
        NotImplementedError: for an unsupported model type, or an unsupported
            keypoint kind in the 1-view / base-point code paths.
    """
    name = "train" if is_train else "val"
    model_type = config.model.name

    if is_train:
        model.train()
    else:
        model.eval()

    metric_dict = defaultdict(list)
    results = defaultdict(list)

    # Loop-invariant config lookups, hoisted out of the per-batch loop.
    scale_keypoints_3d = getattr(config.opt, "scale_keypoints_3d", 1.0)
    use_volumetric_ce_loss = getattr(config.opt, "use_volumetric_ce_loss", False)
    volumetric_ce_loss_weight = getattr(config.opt, "volumetric_ce_loss_weight", 1.0)
    volumetric_ce_criterion = VolumetricCELoss() if use_volumetric_ce_loss else None
    # Anomaly detection is a debugging aid that slows every step down. It stays on
    # by default (previous behaviour); set ``opt.detect_anomaly: false`` to disable.
    detect_anomaly = getattr(config.opt, "detect_anomaly", True)

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad

    with grad_context():
        end = time.time()

        batches = dataloader
        if is_train and config.opt.n_iters_per_epoch is not None:
            batches = islice(dataloader, config.opt.n_iters_per_epoch)

        for batch in batches:
            with autograd.set_detect_anomaly(detect_anomaly):
                # measure data loading time
                data_time = time.time() - end

                if batch is None:
                    print("Found None batch")
                    continue

                images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config)

                keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
                if model_type == "alg" or model_type == "ransac":
                    keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(images_batch, proj_matricies_batch, batch)
                elif model_type == "vol":
                    keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch)
                else:
                    raise NotImplementedError("Unknown model type: {}".format(model_type))

                batch_size, n_views = images_batch.shape[0], images_batch.shape[1]

                keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)

                # 1-view case: express all joints relative to a base joint, for both gt and prediction
                if n_views == 1:
                    if config.kind == "human36m":
                        base_joint = 6
                    elif config.kind == "coco":
                        base_joint = 11
                    else:
                        raise NotImplementedError("1-view case is not supported for kind: {}".format(config.kind))

                    keypoints_3d_gt = _relative_to_joint(keypoints_3d_gt, base_joint)
                    keypoints_3d_pred = _relative_to_joint(keypoints_3d_pred, base_joint)

                # calculate loss
                total_loss = 0.0
                loss = criterion(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
                total_loss += loss
                metric_dict[f'{config.opt.criterion}'].append(loss.item())

                # volumetric ce loss
                if use_volumetric_ce_loss:
                    loss = volumetric_ce_criterion(coord_volumes_pred, volumes_pred, keypoints_3d_gt, keypoints_3d_binary_validity_gt)
                    metric_dict['volumetric_ce_loss'].append(loss.item())
                    total_loss += volumetric_ce_loss_weight * loss

                metric_dict['total_loss'].append(total_loss.item())

                if is_train:
                    opt.zero_grad()
                    total_loss.backward()

                    if hasattr(config.opt, "grad_clip"):
                        # max norm is grad_clip / lr, i.e. grad_clip appears to be in the same
                        # "grad norm x lr" units as the grad_norm_times_lr metric below.
                        torch.nn.utils.clip_grad_norm_(model.parameters(), config.opt.grad_clip / config.opt.lr)

                    metric_dict['grad_norm_times_lr'].append(config.opt.lr * misc.calc_gradient_norm(filter(lambda x: x[1].requires_grad, model.named_parameters())))

                    opt.step()

                # calculate metrics
                l2 = KeypointsL2Loss()(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
                metric_dict['l2'].append(l2.item())

                # base point l2
                if base_points_pred is not None:
                    base_point_l2_list = []
                    for batch_i in range(batch_size):
                        base_point_pred = base_points_pred[batch_i]

                        if config.model.kind == "coco":
                            # midpoint of joints 11 and 12 (previously referenced an undefined `keypoints_3d`)
                            base_point_gt = (keypoints_3d_gt[batch_i, 11, :3] + keypoints_3d_gt[batch_i, 12, :3]) / 2
                        elif config.model.kind == "mpii":
                            base_point_gt = keypoints_3d_gt[batch_i, 6, :3]
                        else:
                            raise NotImplementedError("Base point is not defined for model kind: {}".format(config.model.kind))

                        base_point_l2_list.append(torch.sqrt(torch.sum((base_point_pred * scale_keypoints_3d - base_point_gt * scale_keypoints_3d) ** 2)).item())

                    base_point_l2 = 0.0 if len(base_point_l2_list) == 0 else np.mean(base_point_l2_list)
                    metric_dict['base_point_l2'].append(base_point_l2)

                # save answers for evaluation
                if not is_train:
                    results['keypoints_3d'].append(keypoints_3d_pred.detach().cpu().numpy())
                    results['indexes'].append(batch['indexes'])

                # plot visualization and dump weights to tensorboard
                if master:
                    if n_iters_total % config.vis_freq == 0:
                        vis_kind = config.kind
                        if getattr(config, "transfer_cmu_to_human36m", False):
                            vis_kind = "coco"

                        for batch_i in range(min(batch_size, config.vis_n_elements)):
                            keypoints_vis = vis.visualize_batch(
                                images_batch, heatmaps_pred, keypoints_2d_pred, proj_matricies_batch,
                                keypoints_3d_gt, keypoints_3d_pred,
                                kind=vis_kind,
                                cuboids_batch=cuboids_pred,
                                confidences_batch=confidences_pred,
                                batch_index=batch_i, size=5,
                                max_n_cols=10
                            )
                            writer.add_image(f"{name}/keypoints_vis/{batch_i}", keypoints_vis.transpose(2, 0, 1), global_step=n_iters_total)

                            heatmaps_vis = vis.visualize_heatmaps(
                                images_batch, heatmaps_pred,
                                kind=vis_kind,
                                batch_index=batch_i, size=5,
                                max_n_rows=10, max_n_cols=10
                            )
                            writer.add_image(f"{name}/heatmaps/{batch_i}", heatmaps_vis.transpose(2, 0, 1), global_step=n_iters_total)

                            if model_type == "vol":
                                volumes_vis = vis.visualize_volumes(
                                    images_batch, volumes_pred, proj_matricies_batch,
                                    kind=vis_kind,
                                    cuboids_batch=cuboids_pred,
                                    batch_index=batch_i, size=5,
                                    max_n_rows=1, max_n_cols=16
                                )
                                writer.add_image(f"{name}/volumes/{batch_i}", volumes_vis.transpose(2, 0, 1), global_step=n_iters_total)

                        # dump weights to tensorboard
                        for p_name, p in model.named_parameters():
                            try:
                                writer.add_histogram(p_name, p.clone().cpu().data.numpy(), n_iters_total)
                            except ValueError as e:
                                print(e)
                                print(p_name, p)
                                # Re-raise instead of exit(): exit() terminated with status 0,
                                # which makes a crashed run look successful.
                                raise

                    # dump to tensorboard per-iter loss/metric stats
                    if is_train:
                        for title, value in metric_dict.items():
                            writer.add_scalar(f"{name}/{title}", value[-1], n_iters_total)

                    # measure elapsed time
                    batch_time = time.time() - end
                    end = time.time()

                    # dump to tensorboard per-iter time stats
                    writer.add_scalar(f"{name}/batch_time", batch_time, n_iters_total)
                    writer.add_scalar(f"{name}/data_time", data_time, n_iters_total)

                    # dump to tensorboard per-iter stats about sizes
                    writer.add_scalar(f"{name}/batch_size", batch_size, n_iters_total)
                    writer.add_scalar(f"{name}/n_views", n_views, n_iters_total)

                n_iters_total += 1

    # calculate evaluation metrics
    if master:
        if not is_train:
            results['keypoints_3d'] = np.concatenate(results['keypoints_3d'], axis=0)
            results['indexes'] = np.concatenate(results['indexes'])

            try:
                scalar_metric, full_metric = dataloader.dataset.evaluate(results['keypoints_3d'])
            except Exception as e:
                # best-effort: a failed evaluation should not discard the dumped predictions
                print("Failed to evaluate. Reason: ", e)
                scalar_metric, full_metric = 0.0, {}

            metric_dict['dataset_metric'].append(scalar_metric)

            checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
            os.makedirs(checkpoint_dir, exist_ok=True)

            # dump results
            with open(os.path.join(checkpoint_dir, "results.pkl"), 'wb') as fout:
                pickle.dump(results, fout)

            # dump full metric
            with open(os.path.join(checkpoint_dir, "metric.json"), 'w') as fout:
                json.dump(full_metric, fout, indent=4, sort_keys=True)

        # dump to tensorboard per-epoch stats
        for title, value in metric_dict.items():
            writer.add_scalar(f"{name}/{title}_epoch", np.mean(value), epoch)

    return n_iters_total
def init_distributed(args):
    """Initialise an NCCL ``torch.distributed`` process group from the environment.

    Args:
        args: parsed CLI namespace; reads ``local_rank`` and ``seed``. When
            ``args.local_rank`` is None it is filled in from the ``LOCAL_RANK``
            environment variable (set by ``torchrun``), defaulting to 0, so that
            callers reading ``args.local_rank`` afterwards see the real value.

    Returns:
        False when ``WORLD_SIZE`` is unset or < 1 (single-process run),
        True once the process group has been initialised.

    Raises:
        RuntimeError: if ``MASTER_PORT`` or ``RANK`` is missing or empty.
    """
    if "WORLD_SIZE" not in os.environ or int(os.environ["WORLD_SIZE"]) < 1:
        return False

    # Explicit checks instead of `assert os.environ[...]`: asserts vanish under
    # `python -O`, and indexing raised a bare KeyError instead of the hint.
    # Validation happens before any CUDA call so misconfiguration fails fast.
    if not os.environ.get("MASTER_PORT"):
        raise RuntimeError("set the MASTER_PORT variable or use pytorch launcher")
    if not os.environ.get("RANK"):
        raise RuntimeError("use pytorch launcher and explicityly state the rank of the process")

    if args.local_rank is None:
        args.local_rank = int(os.environ.get("LOCAL_RANK", 0))

    torch.cuda.set_device(args.local_rank)
    torch.manual_seed(args.seed)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    return True
def main(args):
    """Build the model, loss, optimizer and data from the config, then train or evaluate.

    Args:
        args: namespace from ``parse_args()``; reads ``config``, ``eval``,
            ``eval_dataset``, ``local_rank`` and ``seed``.
    """
    print("Number of available GPUs: {}".format(torch.cuda.device_count()))

    is_distributed = init_distributed(args)
    if not is_distributed:
        # init_distributed() only seeds the distributed path; without this,
        # --seed had no effect on single-process runs.
        torch.manual_seed(args.seed)

    master = True
    if is_distributed and os.environ["RANK"]:
        master = int(os.environ["RANK"]) == 0

    if is_distributed:
        device = torch.device(args.local_rank)
    else:
        device = torch.device(0)

    # config
    config = cfg.load_config(args.config)
    config.opt.n_iters_per_epoch = config.opt.n_objects_per_epoch // config.opt.batch_size

    model = {
        "ransac": RANSACTriangulationNet,
        "alg": AlgebraicTriangulationNet,
        "vol": VolumetricTriangulationNet
    }[config.model.name](config, device=device).to(device)

    if config.model.init_weights:
        # map_location: load onto this process's device, not the device the checkpoint was saved from.
        state_dict = torch.load(config.model.checkpoint, map_location=device)
        # Weights saved from a DistributedDataParallel wrapper have a leading "module." prefix.
        # Strip only that prefix; str.replace() would also rewrite inner names containing "module.".
        # The dict is edited in place so attributes torch attaches to it (e.g. _metadata) survive.
        ddp_prefix = "module."
        for key in list(state_dict.keys()):
            if key.startswith(ddp_prefix):
                state_dict[key[len(ddp_prefix):]] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=True)
        print("Successfully loaded pretrained weights for whole model")

    # criterion
    criterion_class = {
        "MSE": KeypointsMSELoss,
        "MSESmooth": KeypointsMSESmoothLoss,
        "MAE": KeypointsMAELoss
    }[config.opt.criterion]

    if config.opt.criterion == "MSESmooth":
        criterion = criterion_class(config.opt.mse_smooth_threshold)
    else:
        criterion = criterion_class()

    # optimizer
    opt = None
    if not args.eval:
        if config.model.name == "vol":
            opt = torch.optim.Adam(
                [{'params': model.backbone.parameters()},
                 {'params': model.process_features.parameters(), 'lr': getattr(config.opt, "process_features_lr", config.opt.lr)},
                 {'params': model.volume_net.parameters(), 'lr': getattr(config.opt, "volume_net_lr", config.opt.lr)}
                ],
                lr=config.opt.lr
            )
        else:
            opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.opt.lr)

    # datasets
    print("Loading data...")
    train_dataloader, val_dataloader, train_sampler = setup_dataloaders(config, distributed_train=is_distributed)

    # experiment
    experiment_dir, writer = None, None
    if master:
        experiment_dir, writer = setup_experiment(config, type(model).__name__, is_train=not args.eval)

    # multi-gpu
    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device])

    if not args.eval:
        # train loop
        n_iters_total_train, n_iters_total_val = 0, 0
        for epoch in range(config.opt.n_epochs):
            if train_sampler is not None:
                # reshuffle the distributed shards differently every epoch
                train_sampler.set_epoch(epoch)

            n_iters_total_train = one_epoch(model, criterion, opt, config, train_dataloader, device, epoch, n_iters_total=n_iters_total_train, is_train=True, master=master, experiment_dir=experiment_dir, writer=writer)
            n_iters_total_val = one_epoch(model, criterion, opt, config, val_dataloader, device, epoch, n_iters_total=n_iters_total_val, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer)

            if master:
                checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
                os.makedirs(checkpoint_dir, exist_ok=True)

                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "weights.pth"))

            print(f"{n_iters_total_train} iters done.")
    else:
        eval_dataloader = train_dataloader if args.eval_dataset == 'train' else val_dataloader
        one_epoch(model, criterion, opt, config, eval_dataloader, device, 0, n_iters_total=0, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer)

    if writer is not None:
        # flush pending TensorBoard events before the process exits
        writer.close()

    print("Done.")
if __name__ == '__main__':
    # `args` is deliberately bound at module level: setup_experiment() reads it as a global.
    args = parse_args()
    print(f"args: {args}")
    main(args)