-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtrain.py
146 lines (110 loc) · 4.16 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# encoding = utf-8
"""
Author: [email protected]
Github: https://github.com/misads
License: MIT
"""
import os
import pdb
import time
import numpy as np
import torch
from torch import optim
from torch.autograd import Variable
import dataloader.dataloaders as dl
from network import get_model
from eval import evaluate
from options import opt
from utils import init_log, load_meta, save_meta
from mscv.summary import create_summary_writer, write_meters_loss
import misc_utils as utils
# ---------------------------------------------------------------------------
# Initialization
# No gradients are needed while wiring up paths, loaders, logging and the
# model, so the whole setup runs under no_grad. (NOTE(review): the scraped
# source lost its indentation — the no_grad scope is reconstructed to cover
# setup only; the training loop below must run with gradients enabled.)
# ---------------------------------------------------------------------------
with torch.no_grad():
    # Checkpoint / log directories, namespaced by the experiment tag.
    save_root = os.path.join(opt.checkpoint_dir, opt.tag)
    log_root = os.path.join(opt.log_dir, opt.tag)

    utils.try_make_dir(save_root)
    utils.try_make_dir(log_root)

    # Dataloaders (val_dataloader may be None when no validation set exists).
    train_dataloader = dl.train_dataloader
    val_dataloader = dl.val_dataloader

    # Logger for this training run.
    logger = init_log(training=True)

    # Start a fresh meta record for this run and persist it immediately.
    meta = load_meta(new=True)
    save_meta(meta)

    # Build the model selected on the command line.
    Model = get_model(opt.model)
    model = Model(opt)

    # Multi-GPU is not supported yet.
    # if len(opt.gpu_ids):
    #     model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model = model.to(device=opt.device)

    # Optionally load a pretrained checkpoint; with --resume, continue the
    # epoch count from where the checkpoint left off.
    if opt.load:
        load_epoch = model.load(opt.load)
        start_epoch = load_epoch + 1 if opt.resume else 1
    else:
        start_epoch = 1

    # Switch to training mode before the loop starts.
    model.train()

    # Step bookkeeping for the ETA estimate.
    print('Start training...')
    start_step = (start_epoch - 1) * len(train_dataloader)
    global_step = start_step
    total_steps = opt.epochs * len(train_dataloader)
    start = time.time()

    # TensorBoard writer.
    writer = create_summary_writer(log_root)

    # Record the transforms in the log so runs are reproducible.
    # (Fixed typo: 'trasforms' -> 'transforms' in the log labels.)
    logger.info('train_transforms: ' + str(train_dataloader.dataset.transforms))
    logger.info('===========================================')
    if val_dataloader is not None:
        logger.info('val_transforms: ' + str(val_dataloader.dataset.transforms))
        logger.info('===========================================')
# ---------------------------------------------------------------------------
# Training loop. On any exit path (success, error, or interrupt) the meta
# record gets a finish timestamp; errors are additionally appended to
# run_log.txt unless this is a throwaway 'cache' run.
# ---------------------------------------------------------------------------
try:
    for epoch in range(start_epoch, opt.epochs + 1):
        for iteration, sample in enumerate(train_dataloader):
            global_step += 1

            # ETA: steps/second since this process started, applied to the
            # steps remaining.
            rate = (global_step - start_step) / (time.time() - start)
            remaining = (total_steps - global_step) / rate

            # One optimization step; the return value feeds the summaries.
            update_return = model.update(sample)

            # Current learning rate; fall back to the configured lr when the
            # model does not expose one.
            lr = model.get_lr()
            lr = lr if lr is not None else opt.lr

            # Progress bar with lr, running losses and ETA.
            msg = f'lr:{round(lr, 6) : .6f} (loss) {str(model.avg_meters)} ETA: {utils.format_time(remaining)}'
            utils.progress_bar(iteration, len(train_dataloader), 'Epoch:%d' % epoch, msg)

            # Write a TensorBoard summary every 1000 steps.
            if global_step % 1000 == 0:
                write_meters_loss(writer, 'train', model.avg_meters, global_step)
                model.write_train_summary(update_return)

        # End-of-epoch log line.
        logger.info(f'Train epoch: {epoch}, lr: {round(lr, 6) : .6f}, (loss) ' + str(model.avg_meters))

        # Save periodically, and always after the final epoch.
        if epoch % opt.save_freq == 0 or epoch == opt.epochs:
            model.save(epoch)

        # Periodic validation; restore training mode afterwards.
        if epoch % opt.eval_freq == 0:
            model.eval()
            evaluate(model, val_dataloader, epoch, writer, logger)
            model.train()

        model.step_scheduler()

    # Normal completion: stamp the finish time into the run's meta record.
    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)
except Exception as e:
    # Log a truncated error message for real runs, stamp the meta record,
    # then re-raise bare so the ORIGINAL traceback is preserved (the old
    # `raise Exception('Error')` replaced it with a useless one).
    if opt.tag != 'cache':
        with open('run_log.txt', 'a') as f:
            f.write(' Error: ' + str(e)[:120] + '\n')
    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)
    raise
except:  # noqa: E722 — KeyboardInterrupt / SystemExit etc.
    # Still record the finish time, then re-raise instead of silently
    # swallowing the interrupt (the original fell through and hid it).
    meta = load_meta()
    meta[-1]['finishtime'] = utils.get_time_stamp()
    save_meta(meta)
    raise