-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_newgc.py
250 lines (201 loc) · 10.7 KB
/
train_newgc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import os
# import sys
# import logging
import math
from config import assert_and_infer_cfg
from datasetsnew.setup_data import setup_loaders
from mindspore import Model, FixedLossScaleManager, load_checkpoint, load_param_into_net
from mindspore import save_checkpoint, context, load_checkpoint, load_param_into_net
import argparse
#import yaml
import mindspore.nn as nn
from mindspore import Model, FixedLossScaleManager
from mindspore import context
# from mindspore import set_seed
# from mindspore.context import ParallelMode
# from mindspore.communication import init
# from mindspore.nn.optim import optimizer
# from mindspore.train.callback import CheckpointConfig
# from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
# from mindspore.train.callback import TimeMonitor
from mindspore.common import set_seed
from trainonestep import TrainOneStep
from generatorloss import Generatorloss
# Argument Parser
from loss import JointEdgeSegLoss
from network.gscnn import GSCNN
# Device selection: CUDA_VISIBLE_DEVICES is unused here (Ascend target below).
#os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# Pin training to Ascend device 1 on this host.
context.set_context(device_id=1)
# PyNative (eager) execution mode on Ascend; the GRAPH_MODE alternative is kept
# below for quick switching — graph mode is usually faster once the model is stable.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
#context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# ---------------------------------------------------------------------------
# Command-line interface. Every training hyper-parameter is exposed as a flag;
# defaults reproduce the original run configuration. Parsed once at import.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='GSCNN')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--arch', type=str, default='network.gscnn.GSCNN')
parser.add_argument('--dataset', type=str, default='cityscapes')
parser.add_argument('--cv', type=int, default=0,
                    help='cross validation split')
parser.add_argument('--joint_edgeseg_loss', action='store_true', default=True,
                    help='joint loss')
parser.add_argument('--img_wt_loss', action='store_true', default=False,
                    help='per-image class-weighted loss')
parser.add_argument('--batch_weighting', action='store_true', default=False,
                    help='Batch weighting for class')
parser.add_argument('--eval_thresholds', type=str, default='0.0005,0.001875,0.00375,0.005',
                    help='Thresholds for boundary evaluation')
parser.add_argument('--rescale', type=float, default=1.0,
                    help='Rescaled LR Rate')
parser.add_argument('--repoly', type=float, default=1.5,
                    help='Rescaled Poly')
parser.add_argument('--edge_weight', type=float, default=1.0,
                    help='Edge loss weight for joint loss')
parser.add_argument('--seg_weight', type=float, default=1.0,
                    help='Segmentation loss weight for joint loss')
parser.add_argument('--att_weight', type=float, default=1.0,
                    help='Attention loss weight for joint loss')
parser.add_argument('--dual_weight', type=float, default=1.0,
                    help='Dual loss weight for joint loss')
parser.add_argument('--evaluate', action='store_true', default=False)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sgd', action='store_true', default=True)
parser.add_argument('--sgd_finetuned', action='store_true', default=False)
parser.add_argument('--adam', action='store_true', default=False)
parser.add_argument('--amsgrad', action='store_true', default=False)
parser.add_argument('--trunk', type=str, default='resnet101',
                    help='trunk model, can be: resnet101 (default), resnet50')
parser.add_argument('--max_epoch', type=int, default=225)
parser.add_argument('--start_epoch', type=int, default=0)
parser.add_argument('--color_aug', type=float,
                    default=0.25, help='level of color augmentation')
parser.add_argument('--rotate', type=float,
                    default=0, help='rotation')
parser.add_argument('--gblur', action='store_true', default=True)
parser.add_argument('--bblur', action='store_true', default=False)
parser.add_argument('--lr_schedule', type=str, default='poly',
                    help='name of lr schedule: poly')
parser.add_argument('--poly_exp', type=float, default=1.0,
                    help='polynomial LR exponent')
parser.add_argument('--bs_mult', type=int, default=1)
parser.add_argument('--bs_mult_val', type=int, default=2)
parser.add_argument('--crop_size', type=int, default=720,
                    help='training crop size')
parser.add_argument('--pre_size', type=int, default=None,
                    help='resize image shorter edge to this before augmentation')
parser.add_argument('--scale_min', type=float, default=0.5,
                    help='dynamically scale training images down to this size')
parser.add_argument('--scale_max', type=float, default=2.0,
                    help='dynamically scale training images up to this size')
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--snapshot', type=str, default=None)
parser.add_argument('--restore_optimizer', action='store_true', default=False)
parser.add_argument('--exp', type=str, default='default',
                    help='experiment directory name')
parser.add_argument('--tb_tag', type=str, default='',
                    help='add tag to tb dir')
parser.add_argument('--ckpt', type=str, default='logs/ckpt')
parser.add_argument('--tb_path', type=str, default='logs/tb')
parser.add_argument('--syncbn', action='store_true', default=True,
                    help='Synchronized BN')
# FIX: was action='store_true' with a string default "GSCNN"; passing the flag
# on the command line set it to True, yielding checkpoint names like "0_True.ckpt"
# (model_type is interpolated into the save path by train()). A plain string
# option preserves the default while making the flag usable.
parser.add_argument('--model_type', type=str, default="GSCNN",
                    help='model identifier used in checkpoint file names')
# FIX: help text was a copy-paste of the --syncbn help ('Synchronized BN').
parser.add_argument('--dump_augmentation_images', action='store_true', default=False,
                    help='dump augmented training images for inspection')
parser.add_argument('--test_mode', action='store_true', default=False,
                    help='minimum testing (1 epoch run ) to verify nothing failed')
parser.add_argument('-wb', '--wt_bound', type=float, default=1.0)
parser.add_argument('--maxSkip', type=int, default=0)
parser.add_argument('--save_folder', default='/home/heu_MEDAI/gongcheng/New19Channel/xiaocheckpoint',
                    help='Location to save epoch models')
parser.add_argument('--model_path', type=str,
                    default="/home/heu_MEDAI/gongcheng/19Channel/xiaocheckpoint/GSCNN_ckpt_57-23_95.ckpt",
                    help='Path to model file created by training')
args = parser.parse_args()

# Best-validation bookkeeping, filled in during evaluation elsewhere.
args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
                    'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
args.world_size = 1

# Test mode: shrink the run to two epochs to verify nothing fails end-to-end.
if args.test_mode:
    args.max_epoch = 2

# Distributed launchers export WORLD_SIZE; pick it up when present.
if 'WORLD_SIZE' in os.environ:
    args.world_size = int(os.environ['WORLD_SIZE'])
    print("Total world size: ", int(os.environ['WORLD_SIZE']))

epoch_loss = []  # per-epoch average training loss, appended by train()
assert_and_infer_cfg(args)
def train(trainoneStep, data):
    """Manual training loop: iterate `data` for args.max_epoch epochs,
    stepping `trainoneStep` on each batch and checkpointing per epoch.

    Args:
        trainoneStep: TrainOneStep cell wrapping the loss network and optimizer;
            called with (img, mask, edgemap, canny, label_panduan) tensors and
            returns the scalar loss.
        data: MindSpore dataset providing `get_dataset_size()` and
            `create_dict_iterator()` with the keys read below.

    Side effects: appends each epoch's mean loss to the module-level
    `epoch_loss` list and writes '{epoch}_{model_type}.ckpt' files to
    args.save_folder.
    """
    trainoneStep.set_train()
    trainoneStep.set_grad()
    steps = data.get_dataset_size()
    for epoch in range(args.max_epoch):
        total_loss = 0.0
        num_batches = 0
        for num_batches, batch in enumerate(data.create_dict_iterator(), 1):
            # 'input' renamed to avoid shadowing the builtin.
            image = batch["img"]
            mask = batch["mask"]
            edge = batch["edgemap"]
            canny = batch["canny"]
            label_panduan = batch["label_panduan"]
            loss = trainoneStep(image, mask, edge, canny, label_panduan)
            loss = loss.asnumpy()
            # FIX: the original format string had five placeholders but four
            # arguments (and called .asnumpy() on an already-converted value),
            # so every step raised. No per-step timing is computed here, so the
            # stepTime placeholder is dropped.
            print("epoch[{}]({}/{}), loss: {:.4f}".format(
                epoch + 1, num_batches, steps, float(loss)))
            total_loss += loss
        # FIX: guard against an empty dataset (original divided by zero),
        # and the original summary format had four placeholders for two args.
        train_loss = total_loss / max(num_batches, 1)
        epoch_loss.append(train_loss)
        print("Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, train_loss))
        save_ckpt = os.path.join(args.save_folder,
                                 '{}_{}.ckpt'.format(epoch, args.model_type))
        save_checkpoint(trainoneStep.network, save_ckpt)
def poly_lr(base_lr, epoch_steps, total_steps, end_lr=0.0001, power=0.9):
    """Yield a learning rate for each of `total_steps` steps, multiplying it
    by 0.98 once at the start of every `epoch_steps`-step epoch.

    FIX: the original `if` line was missing its trailing colon, which was a
    SyntaxError that prevented this whole file from being imported or run.
    NOTE(review): `end_lr` and `power` are accepted but currently unused —
    presumably leftovers from a true polynomial schedule; confirm before
    removing. Hoisted to module level (backward-compatible; it had no callers).
    """
    for i in range(total_steps):
        if i % epoch_steps == 0:  # decay once per epoch boundary
            base_lr = base_lr * 0.98
        yield base_lr


if __name__ == '__main__':
    set_seed(42)
    train_dataset, val_dataset = setup_loaders(args)
    print("分类种类", args.dataset_cls.num_classes)
    net = GSCNN(num_classes=args.dataset_cls.num_classes)
    print(net)
    # Optional warm-start from a saved checkpoint (disabled):
    # param_dict = load_checkpoint(args.model_path)
    # load_param_into_net(net, param_dict)

    # Learning-rate schedule: piecewise-constant values following a
    # poly(1.0) decay over a 175-epoch horizon, scaled by the 0.01 base LR.
    milestone = list(range(1, 10))
    learning_rates = [math.pow(1 - i / 175, 1.0) * 0.01 for i in range(1, 10)]
    lr = nn.piecewise_constant_lr(milestone, learning_rates)
    # NOTE(review): `lr` is only consumed by the commented-out Momentum
    # optimizer below; the active SGD uses a fixed 0.01.

    # Optimizer
    param_groups = net.trainable_params()
    optim = nn.SGD(param_groups, learning_rate=0.01, weight_decay=0.0001,
                   momentum=0.9, nesterov=False, loss_scale=1.0)
    # optim = nn.Adam(param_groups, learning_rate=0.01, weight_decay=args.weight_decay)
    # optim = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9)

    lossNetwork = Generatorloss(net)
    trainonestepNet = TrainOneStep(lossNetwork, optim)
    # train(trainonestepNet, train_dataset)  # manual-loop alternative to Model.train

    # Callbacks: loss logging plus periodic checkpointing (every 95 steps,
    # keeping at most 2 files). time_cb is created but not registered in `cb`.
    loss_cb = LossMonitor()
    time_cb = TimeMonitor()
    config_ck = CheckpointConfig(save_checkpoint_steps=95, keep_checkpoint_max=2)
    ckpt_cb = ModelCheckpoint(prefix='GSCNN_ckpt',
                              directory=args.save_folder,
                              config=config_ck)
    cb = [loss_cb, ckpt_cb]
    # loss_scale = FixedLossScaleManager(0.5, drop_overflow_update=False)
    # model = Model(net, optimizer=optim, loss_scale_manager=loss_scale)
    model = Model(trainonestepNet)
    model.train(epoch=args.max_epoch, train_dataset=train_dataset,
                callbacks=cb, dataset_sink_mode=False)