# trainer.py -- forked from Project-MONAI/tutorials
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Tuple
import torch
from ignite.engine import Engine
from monai.engines import SupervisedTrainer
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents
from torch.nn.parallel import DistributedDataParallel
class DynUNetTrainer(SupervisedTrainer):
    """
    Supervised trainer for DynUNet on Decathlon datasets.

    This class inherits from ``SupervisedTrainer`` in MONAI and overrides the
    per-iteration logic so that it:

    - accepts deep-supervision output (a prediction tensor with one more
      dimension than the label), unbinding it along ``dim=1`` into a tuple of
      per-head predictions;
    - weights the loss of the ``i``-th prediction head by ``0.5**i``;
    - clips the total gradient norm to ``max_grad_norm`` before every
      optimizer step.
    """

    # Upper bound on the total gradient norm handed to ``clip_grad_norm_``.
    # Subclasses may override this to change the clipping threshold.
    max_grad_norm: float = 12

    def _clip_gradients(self) -> None:
        """Clip the total gradient norm of the network to ``max_grad_norm``."""
        network = self.network
        if isinstance(network, DistributedDataParallel):
            # Clip on the parameters of the wrapped module.
            network = network.module
        torch.nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)

    def _iteration(self, engine: Engine, batchdata: Dict[str, Any]):
        """
        Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.

        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model; a tuple of Tensors in deep
              supervision mode, otherwise whatever the inferer returned.
            - LOSS: loss value computed by loss function.

        The events FORWARD_COMPLETED, LOSS_COMPLETED, BACKWARD_COMPLETED and
        MODEL_COMPLETED are fired, in that order, in both the AMP and the
        full-precision code paths.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")

        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

        def _compute_pred_loss():
            preds = self.inferer(inputs, self.network, *args, **kwargs)
            if len(preds.size()) - len(targets.size()) == 1:
                # deep supervision mode, need to unbind feature maps first.
                preds = torch.unbind(preds, dim=1)
            engine.state.output[Keys.PRED] = preds
            # Drop the local reference; the engine state now owns the predictions.
            del preds
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)

            # Re-read PRED from the engine state so that any change made by a
            # FORWARD_COMPLETED handler is honoured.
            heads = engine.state.output[Keys.PRED]
            if not isinstance(heads, (tuple, list)):
                # Single-output mode: treat the lone tensor as one head rather
                # than iterating over its first (batch) dimension.
                heads = (heads,)
            # Each successive head contributes half the weight of the previous one.
            # The loss is called directly (not via ``.forward``) so module hooks
            # run and plain callables are accepted as well.
            engine.state.output[Keys.LOSS] = sum(
                0.5**i * self.loss_function(p, targets) for i, p in enumerate(heads)
            )
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        self.network.train()
        self.optimizer.zero_grad()
        if self.amp and self.scaler is not None:
            with torch.cuda.amp.autocast():
                _compute_pred_loss()
            self.scaler.scale(engine.state.output[Keys.LOSS]).backward()
            # Fired here for parity with the full-precision branch; note that
            # gradients are still multiplied by the loss scale at this point.
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            # Unscale before clipping so the threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            self._clip_gradients()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[Keys.LOSS].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self._clip_gradients()
            self.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output