forked from Scalsol/mega.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
196 lines (178 loc) · 7.4 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import os
import time
import torch
import torch.distributed as dist
from tqdm import tqdm
from mega_core.data import make_data_loader
from mega_core.utils.comm import get_world_size, synchronize
from mega_core.utils.metric_logger import MetricLogger
from mega_core.engine.inference import inference
from apex import amp
def reduce_loss_dict(loss_dict):
    """
    Average a dict of loss tensors onto the rank-0 process.

    With fewer than two processes the input dict itself is returned untouched.
    Otherwise the values (presumably same-shaped scalar tensors, since they are
    stacked together) are stacked in sorted-key order and summed onto rank 0
    with ``dist.reduce``. Only rank 0 then divides by the world size, so only
    rank 0 ends up holding true averages. A new dict with the same keys as
    ``loss_dict`` is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        names = sorted(loss_dict)
        stacked = torch.stack([loss_dict[name] for name in names], dim=0)
        dist.reduce(stacked, dst=0)
        if dist.get_rank() == 0:
            # dist.reduce delivers the sum to the destination rank only,
            # so only that rank is turned into an average.
            stacked /= world_size
        reduced = dict(zip(names, stacked))
    return reduced
def _move_images_to_device(images, cfg, device):
    """
    Move one batch of images onto ``device`` according to the configured method.

    If ``cfg.MODEL.VID.ENABLE`` is false, or the VID method is ``"base"``,
    ``images`` is moved with a single ``.to(device)`` and the result returned.
    For the ``"rdn"``, ``"mega"``, ``"fgfa"`` and ``"dff"`` VID methods,
    ``images`` is a dict. Its ``"cur"`` entry is moved directly, and each
    optional ``"ref"``/``"ref_l"``/``"ref_m"``/``"ref_g"`` entry is a list
    whose elements are moved one by one. The dict is updated in place and
    returned.

    Raises:
        ValueError: if VID is enabled and ``cfg.MODEL.VID.METHOD`` is none of
            the methods listed above.
    """
    if not cfg.MODEL.VID.ENABLE:
        return images.to(device)
    method = cfg.MODEL.VID.METHOD
    if method in ("base", ):
        return images.to(device)
    if method in ("rdn", "mega", "fgfa", "dff"):
        images["cur"] = images["cur"].to(device)
        for key in ("ref", "ref_l", "ref_m", "ref_g"):
            if key in images.keys():
                images[key] = [img.to(device) for img in images[key]]
        return images
    raise ValueError("method {} not supported yet.".format(method))


def _any_across_ranks(flag, device):
    """
    Return True if ``flag`` is True on at least one process.

    In a single-process run this simply returns ``flag``. Otherwise it performs
    a collective ``all_reduce``, so every rank must call it at the same point
    of the same iteration.
    """
    if get_world_size() < 2:
        return flag
    votes = torch.tensor([1.0 if flag else 0.0], device=device)
    dist.all_reduce(votes)  # default reduce op is SUM
    return votes.item() > 0


def _validation_losses(model, data_loader_val, cfg, device):
    """
    Accumulate losses over ``data_loader_val`` into a fresh ``MetricLogger``.

    Runs under ``torch.no_grad()`` and leaves the model in whatever mode the
    caller set. ``do_train`` calls ``model.train()`` first, presumably because
    the model only returns a loss dict in training mode (TODO confirm). Note
    that any non-frozen normalization statistics would also update in train
    mode.
    """
    meters_val = MetricLogger(delimiter=" ")
    with torch.no_grad():
        # Should be one image for each GPU:
        for images_val, targets_val, _ in tqdm(data_loader_val):
            # Same device handling as training, so VID dict batches work too.
            images_val = _move_images_to_device(images_val, cfg, device)
            targets_val = [target.to(device) for target in targets_val]
            loss_dict = model(images_val, targets_val)
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            meters_val.update(loss=losses_reduced, **loss_dict_reduced)
    return meters_val


def do_train(
    cfg,
    model,
    data_loader,
    data_loader_val,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    test_period,
    arguments,
):
    """
    Run the main training loop.

    Args:
        cfg: config node. Reads MODEL.MASK_ON, MODEL.KEYPOINT_ON, MODEL.VID.*,
            MODEL.RETINANET_ON, MODEL.RPN_ONLY, MODEL.DEVICE and
            TEST.EXPECTED_RESULTS(_SIGMA_TOL).
        model: called as ``model(images, targets)``; must return a dict of
            loss tensors.
        data_loader: yields ``(images, targets, image_ids)``. Its ``len`` is
            used as the final iteration number.
        data_loader_val: optional loader used for periodic validation losses.
            Validation runs only if it is not None and ``test_period > 0``.
        optimizer, scheduler: stepped once per non-skipped iteration.
        checkpointer: ``save`` is called every ``checkpoint_period`` iterations
            and as ``"model_final"`` at the last iteration.
        device: target device for images, targets and cross-rank flags.
        checkpoint_period: iterations between checkpoints (must be non-zero).
        test_period: iterations between validation runs (<= 0 disables them).
        arguments: dict whose ``"iteration"`` key gives the start iteration.
            It is updated in place and forwarded to ``checkpointer.save``.
    """
    logger = logging.getLogger("mega_core.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)

    for iteration, (images, targets, image_ids) in enumerate(data_loader, start_iter):
        # Batches containing an empty target are skipped. The decision is
        # shared across ranks: if only one process skipped, the collectives in
        # reduce_loss_dict (and, presumably, the distributed backward) would be
        # mismatched between ranks and training would hang.
        local_empty = any(len(target) < 1 for target in targets)
        if local_empty:
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {image_ids} "
                f"|| targets Length={[len(target) for target in targets]}"
            )
        if _any_across_ranks(local_empty, device):
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = _move_images_to_device(images, cfg, device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()
        scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if data_loader_val is not None and test_period > 0 and iteration % test_period == 0:
            synchronize()
            # The result can be used for additional logging, e. g. for TensorBoard
            inference(
                model,
                # The method changes the segmentation mask format in a data loader,
                # so every time a new data loader is created:
                make_data_loader(cfg, is_train=False, is_distributed=(get_world_size() > 1), is_for_period=True),
                dataset_name="[Validation]",
                iou_types=iou_types,
                box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=None,
            )
            synchronize()
            # Back to train mode for the loss pass (and the rest of training).
            model.train()
            meters_val = _validation_losses(model, data_loader_val, cfg, device)
            synchronize()
            logger.info(
                meters_val.delimiter.join(
                    [
                        "[Validation]: ",
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # When resuming, only (max_iter - start_iter) iterations ran in this call;
    # clamp to 1 so an empty run does not divide by zero.
    iterations_run = max(max_iter - start_iter, 1)
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / iterations_run
        )
    )