train_ae.py
import os
import time
import argparse
import torch
import tensorflow as tf
from lib.auto_encoder import PointCloudAE
from lib.loss import ChamferLoss
from data.shape_dataset import ShapeDataset
from lib.utils import setup_logger
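
# NOTE: tf.summary.FileWriter and tf.Summary used below are TensorFlow 1.x APIs;
# under TensorFlow 2.x they are only reachable via tf.compat.v1.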
parser = argparse.ArgumentParser()
parser.add_argument('--num_point', type=int, default=1024, help='number of points per shape [default: 1024]')
parser.add_argument('--emb_dim', type=int, default=512, help='dimension of latent embedding [default: 512]')
parser.add_argument('--h5_file', type=str, default='data/obj_models/ShapeNetCore_4096.h5', help='h5 file')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--num_workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--start_epoch', type=int, default=1, help='which epoch to start')
parser.add_argument('--max_epoch', type=int, default=50, help='max number of epochs to train')
parser.add_argument('--resume_model', type=str, default='', help='resume from saved model')
parser.add_argument('--result_dir', type=str, default='results/ae_points', help='directory to save train results')
opt = parser.parse_args()
opt.repeat_epoch = 10
opt.decay_step = 5000
opt.decay_rate = [1.0, 0.6, 0.3, 0.1]
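
# Step-based LR decay: at the start of each epoch, once global_step has crossed the
# next multiple of decay_step, the learning rate moves to the next entry of decay_rate
# (at most one move per epoch, since the check only runs between epochs). With the
# defaults (lr=1e-4, decay_step=5000) this gives roughly:
#   steps     0- 4999: 1e-4 * 1.0 = 1e-4
#   steps  5000- 9999: 1e-4 * 0.6 = 6e-5
#   steps 10000-14999: 1e-4 * 0.3 = 3e-5
#   steps 15000+     : 1e-4 * 0.1 = 1e-5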
def train_net():
    # set result directory
    if not os.path.exists(opt.result_dir):
        os.makedirs(opt.result_dir)
    tb_writer = tf.summary.FileWriter(opt.result_dir)
    logger = setup_logger('train_log', os.path.join(opt.result_dir, 'log.txt'))
    for key, value in vars(opt).items():
        logger.info(key + ': ' + str(value))
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
    # model & loss
    estimator = PointCloudAE(opt.emb_dim, opt.num_point)
    estimator.cuda()
    criterion = ChamferLoss()
    if opt.resume_model != '':
        estimator.load_state_dict(torch.load(opt.resume_model))
    # dataset
    train_dataset = ShapeDataset(opt.h5_file, mode='train', augment=True)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size,
                                                   shuffle=True, num_workers=opt.num_workers)
    val_dataset = ShapeDataset(opt.h5_file, mode='val', augment=False)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=opt.batch_size,
                                                 shuffle=False, num_workers=opt.num_workers)
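    # ShapeDataset yields (points, label) pairs; judging from the slicing in the
    # loops below, points have shape (B, N, C) with xyz in the first three channels
    # (any extra channels are dropped before the forward pass).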
    # train
    st_time = time.time()
    global_step = ((train_dataset.length + opt.batch_size - 1) // opt.batch_size) * opt.repeat_epoch * (opt.start_epoch - 1)
    decay_count = -1
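    # When resuming (start_epoch > 1), global_step is rebuilt from the epoch count
    # so the step-based LR decay below picks up where it left off: each epoch runs
    # ceil(len(dataset) / batch_size) * repeat_epoch optimizer steps.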
    for epoch in range(opt.start_epoch, opt.max_epoch+1):
        # train one epoch
        logger.info('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + \
                    ', ' + 'Epoch %02d' % epoch + ', ' + 'Training started'))
        # create optimizer and adjust learning rate if needed
        # (re-creating Adam at each decay also resets its moment estimates)
        if global_step // opt.decay_step > decay_count:
            decay_count += 1
            if decay_count < len(opt.decay_rate):
                current_lr = opt.lr * opt.decay_rate[decay_count]
                optimizer = torch.optim.Adam(estimator.parameters(), lr=current_lr)
        batch_idx = 0
        estimator.train()
        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(train_dataloader):
                # label must be zero-indexed
                batch_xyz, batch_label = data
                batch_xyz = batch_xyz[:, :, :3].cuda()
                optimizer.zero_grad()
                embedding, point_cloud = estimator(batch_xyz)
                loss, _, _ = criterion(point_cloud, batch_xyz)
                summary = tf.Summary(value=[tf.Summary.Value(tag='learning_rate', simple_value=current_lr),
                                            tf.Summary.Value(tag='train_loss', simple_value=loss.item())])
                # backward
                loss.backward()
                optimizer.step()
                global_step += 1
                batch_idx += 1
                # write results to tensorboard
                tb_writer.add_summary(summary, global_step)
                if batch_idx % 10 == 0:
                    logger.info('Batch {0} Loss:{1:f}'.format(batch_idx, loss.item()))
        logger.info('>>>>>>>>----------Epoch {:02d} train finish---------<<<<<<<<'.format(epoch))
        # evaluate one epoch
        logger.info('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + \
                    ', ' + 'Epoch %02d' % epoch + ', ' + 'Testing started'))
        estimator.eval()
        val_loss = 0.0
        # no gradients are needed for validation, so skip building the autograd graph
        with torch.no_grad():
            for i, data in enumerate(val_dataloader, 1):
                batch_xyz, batch_label = data
                batch_xyz = batch_xyz[:, :, :3].cuda()
                embedding, point_cloud = estimator(batch_xyz)
                loss, _, _ = criterion(point_cloud, batch_xyz)
                val_loss += loss.item()
                logger.info('Batch {0} Loss:{1:f}'.format(i, loss.item()))
        val_loss = val_loss / i
        summary = tf.Summary(value=[tf.Summary.Value(tag='val_loss', simple_value=val_loss)])
        tb_writer.add_summary(summary, global_step)
        logger.info('Epoch {0:02d} test average loss: {1:06f}'.format(epoch, val_loss))
        logger.info('>>>>>>>>----------Epoch {:02d} test finish---------<<<<<<<<'.format(epoch))
        # save model after each epoch
        torch.save(estimator.state_dict(), '{0}/model_{1:02d}.pth'.format(opt.result_dir, epoch))
if __name__ == '__main__':
    train_net()
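
# Example usage (flags and paths are the defaults defined above):
#   python train_ae.py --gpu 0 --h5_file data/obj_models/ShapeNetCore_4096.h5
# To resume from a saved checkpoint, e.g. the one written after epoch 20:
#   python train_ae.py --start_epoch 21 --resume_model results/ae_points/model_20.pth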