-
Notifications
You must be signed in to change notification settings - Fork 5
/
finetune_pl.py
155 lines (139 loc) · 4.34 KB
/
finetune_pl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import pathlib
import time
import click
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split
from datasets import ModelNet40
from executor import MeshDataEncoderPL
def _make_loader(dataset, batch_size, *, shuffle, num_workers):
    """Wrap ``dataset`` in a DataLoader that drops the trailing partial batch.

    ``drop_last=True`` guarantees no batch of size 1 is produced, which
    BatchNorm layers reject in training mode.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
    )


@click.command()
@click.option('--train_dataset', help='The training dataset file path')
@click.option(
    '--split_ratio',
    default=0.8,
    help='The proportion of training samples out of the whole training dataset',
)
@click.option('--eval_dataset', help='The evaluation dataset file path')
@click.option(
    '--embed_dim', default=512, help='The embedding dimension of the final outputs'
)
@click.option('--hidden_dim', default=1024, help='The dimension of the used models')
@click.option(
    '--checkpoint_path',
    type=click.Path(file_okay=True, path_type=pathlib.Path),
    help='The path of checkpoint',
)
@click.option(
    '--output_path',
    type=click.Path(file_okay=True, path_type=pathlib.Path),
    help='The path of output files',
)
@click.option(
    '--model_name',
    default='pointnet',
    type=click.Choice(
        ['pointnet', 'pointnet2', 'curvenet', 'pointmlp', 'pointconv', 'repsurf']
    ),
    help='The model name',
)
@click.option('--batch_size', default=128, help='The size of each batch')
@click.option('--epochs', default=50, help='The epochs of training process')
@click.option('--use-gpu/--no-use-gpu', default=True, help='If True to use gpu')
@click.option(
    '--interactive', default=False, help='set to True if you have unlabeled data'
)
@click.option(
    '--devices', default=7, help='The number of gpus/tpus you can use for training'
)
@click.option('--seed', default=10, help='The random seed for reproducing results')
@click.option(
    '--num_workers',
    default=8,
    help='The number of worker processes used by each DataLoader',
)
def main(
    train_dataset,
    split_ratio,
    eval_dataset,
    model_name,
    embed_dim,
    hidden_dim,
    batch_size,
    epochs,
    use_gpu,
    checkpoint_path,
    output_path,
    interactive,
    devices,
    seed,
    num_workers,
):
    """Fine-tune a MeshDataEncoderPL model on ModelNet40-format data.

    Splits TRAIN_DATASET into training/validation subsets, trains with
    PyTorch Lightning, then reports metrics on the validation subset and
    on EVAL_DATASET.
    \f
    ``interactive`` is accepted for CLI compatibility but is not used by
    this command.
    """
    # Use the user-supplied --seed. It used to be overwritten with the
    # wall-clock time, which made every run irreproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    device = 'cuda' if use_gpu else 'cpu'

    if checkpoint_path:
        # Resume from a Lightning checkpoint; the model-shape CLI options
        # (model_name, embed_dim, hidden_dim, batch_size) are not passed here.
        model = MeshDataEncoderPL.load_from_checkpoint(
            checkpoint_path, map_location=device
        )
    else:
        model = MeshDataEncoderPL(
            default_model_name=model_name,
            embed_dim=embed_dim,
            device=device,
            hidden_dim=hidden_dim,
            batch_size=batch_size,
        )

    train_and_val_data = ModelNet40(train_dataset, seed=seed)
    total_len = len(train_and_val_data)
    train_len = int(total_len * split_ratio)
    # A dedicated generator makes the split depend only on ``seed``, not on
    # how many random numbers model construction consumed beforehand.
    train_data, validate_data = random_split(
        train_and_val_data,
        [train_len, total_len - train_len],
        generator=torch.Generator().manual_seed(seed),
    )
    test_data = ModelNet40(eval_dataset, seed=seed)

    train_loader = _make_loader(
        train_data, batch_size, shuffle=True, num_workers=num_workers
    )
    # NOTE(review): drop_last also applies to the evaluation loaders, so up to
    # batch_size - 1 validation/test samples are excluded from the reported
    # metrics. Kept as-is pending confirmation that every model handles a
    # final batch of size 1 in eval mode.
    validate_loader = _make_loader(
        validate_data, batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = _make_loader(
        test_data, batch_size, shuffle=False, num_workers=num_workers
    )

    logger = TensorBoardLogger(
        save_dir='./finetune_logs' if output_path is None else output_path,
        log_graph=True,
        name='{}_hidden_{}_embed_{}_batch_{}_epochs_{}_seed_{}'.format(
            model_name, hidden_dim, embed_dim, batch_size, epochs, seed
        ),
    )
    # Keep the five checkpoints with the lowest logged 'val_loss'.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=5,
        monitor='val_loss',
        mode='min',
        filename='{epoch:02d}-{val_loss:.2f}',
    )
    trainer = Trainer(
        accelerator='gpu' if use_gpu else 'cpu',
        devices=devices,
        max_epochs=epochs,
        check_val_every_n_epoch=1,
        enable_checkpointing=True,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    model.train()
    trainer.fit(model, train_loader, validate_loader)
    print(checkpoint_callback.best_model_path)

    # Evaluation below uses the in-memory weights left after the last epoch.
    model.eval()
    print('Validation set:')
    trainer.test(model, dataloaders=validate_loader)
    print('Testing set:')
    trainer.test(model, dataloaders=test_loader)


if __name__ == '__main__':
    main()