-
Notifications
You must be signed in to change notification settings - Fork 16
/
main.py
87 lines (69 loc) · 2.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2019-05-20 16:52
@Author : Wang Xin
@Email : [email protected]
@File : main.py
"""
import os
import random
import numpy as np
import torch
from torch.backends import cudnn
from network import get_model, get_train_params
from options import Options
def _seed_everything(seed):
    """Seed the torch (CPU and CUDA), NumPy and stdlib ``random`` RNGs with *seed*."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _restore_checkpoint(opt):
    """Rebuild the model and training state from the checkpoint at ``opt.restore``.

    Returns:
        (model, optimizer, start_iter, best_result), where ``start_iter`` is
        the saved ``'epoch'`` plus one.

    Raises:
        FileNotFoundError: if ``opt.restore`` is not an existing file.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(opt.restore):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(opt.restore))
    print("=> loading checkpoint '{}'".format(opt.restore))
    checkpoint = torch.load(opt.restore)

    start_iter = checkpoint['epoch'] + 1
    best_result = checkpoint['best_result']
    # NOTE(review): if 'optimizer' holds a pickled Optimizer object, its param
    # groups reference the tensors unpickled from the checkpoint, not the
    # parameters of the model rebuilt below, so stepping it would not update
    # `model`. Consider rebuilding the optimizer over model's parameters and
    # loading a saved optimizer.state_dict() instead -- confirm against how the
    # trainer saves checkpoints and where it moves the model to the GPU.
    optimizer = checkpoint['optimizer']

    model = get_model(opt)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # Drop the checkpoint reference so its tensors can be released before
    # asking the CUDA allocator to return cached memory.
    del checkpoint
    torch.cuda.empty_cache()
    return model, optimizer, start_iter, best_result


def _create_model(opt):
    """Build a fresh model and its SGD optimizer for training from scratch.

    Returns:
        (model, optimizer, start_iter, best_result) with ``start_iter == 1``
        and ``best_result is None``.
    """
    print("=> creating Model")
    model = get_model(opt)
    print("=> model created.")

    # Different modules have different learning rates, so the optimizer is
    # built from the parameter groups returned by get_train_params.
    train_params = get_train_params(opt, model)
    optimizer = torch.optim.SGD(train_params, lr=opt.lr, momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
    return model, optimizer, 1, None


def main():
    """Parse options, configure devices and RNG seeds, build or restore the
    model, and run training/evaluation with the single- or multi-GPU trainer.
    """
    opt = Options()
    opt.parse_command()
    opt.print_items()

    # Restrict CUDA to the requested device(s). The variable is read when CUDA
    # initialises, so it must be set before any CUDA work below.
    # (Previously referenced the undefined name `args`.)
    if opt.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)

    _seed_everything(opt.manual_seed)
    cudnn.benchmark = True

    # Use one predicate for both the status message and the trainer choice.
    # Previously 0 devices printed "Single GPU Mode." but then picked the
    # multi-GPU trainer.
    n_gpus = torch.cuda.device_count()
    multi_gpu = n_gpus > 1
    if multi_gpu:
        print('Multi-GPUs Mode.')
        print("Let's use ", n_gpus, " GPUs!")
    else:
        print('Single GPU Mode.')
        print("Let's use GPU:", opt.gpu)

    if opt.restore:
        model, optimizer, start_iter, best_result = _restore_checkpoint(opt)
    else:
        model, optimizer, start_iter, best_result = _create_model(opt)

    # Trainer modules are imported lazily so only the one in use is loaded.
    if multi_gpu:
        from libs.trainers import multi_gpu_trainer as trainer_module
    else:
        from libs.trainers import single_gpu_trainer as trainer_module
    trainer = trainer_module.trainer(opt, model, optimizer, start_iter, best_result)
    trainer.train_eval()
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()