forked from rwightman/efficientdet-pytorch
loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List


def focal_loss(logits, targets, alpha: float, gamma: float, normalizer):
    """Compute the focal loss between `logits` and the golden `target` values.

    Focal loss = -(1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.

    Args:
        logits: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        targets: A float32 tensor of size [batch, height_in, width_in, num_predictions].
        alpha: A float32 scalar multiplying alpha to the loss from positive examples
            and (1-alpha) to the loss from negative examples.
        gamma: A float32 scalar modulating loss from hard and easy examples.
        normalizer: A float32 scalar that normalizes the total loss from all examples.

    Returns:
        loss: A float32 scalar representing normalized total loss.
    """
    positive_label_mask = targets == 1.0
    cross_entropy = F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype), reduction='none')

    # Below are comments/derivations for computing the modulator.
    # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmoid(x)
    # for positive samples and 1 - sigmoid(x) for negative examples.
    #
    # The modulator, defined as (1 - p_t)^r, is a critical part of the focal loss
    # computation. For r > 0, it puts more weight on hard examples and less
    # weight on easier ones. However, if it is directly computed as (1 - p_t)^r,
    # its back-propagation is not stable when r < 1. The implementation here
    # resolves the issue.
    #
    # For positive samples (labels being 1),
    #     (1 - p_t)^r
    #   = (1 - sigmoid(x))^r
    #   = (1 - (1 / (1 + exp(-x))))^r
    #   = (exp(-x) / (1 + exp(-x)))^r
    #   = exp(log((exp(-x) / (1 + exp(-x)))^r))
    #   = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
    #   = exp(-r * x - r * log(1 + exp(-x)))
    #
    # For negative samples (labels being 0),
    #     (1 - p_t)^r
    #   = (sigmoid(x))^r
    #   = (1 / (1 + exp(-x)))^r
    #   = exp(log((1 / (1 + exp(-x)))^r))
    #   = exp(-r * log(1 + exp(-x)))
    #
    # Therefore one unified form for positive (z = 1) and negative (z = 0)
    # samples is:
    #     (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
    neg_logits = -1.0 * logits
    modulator = torch.exp(gamma * targets * neg_logits - gamma * torch.log1p(torch.exp(neg_logits)))
    loss = modulator * cross_entropy
    weighted_loss = torch.where(positive_label_mask, alpha * loss, (1.0 - alpha) * loss)
    weighted_loss /= normalizer
    return weighted_loss
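

# Illustrative sanity check, a minimal sketch: verifies that the numerically stable
# modulator used in focal_loss above matches the naive (1 - p_t)^gamma form derived
# in the comments. Shapes and values here are arbitrary and chosen for the check only.
def _check_focal_modulator(gamma: float = 2.0):
    logits = torch.randn(8, 4)
    targets = (torch.rand(8, 4) > 0.5).to(logits.dtype)
    p = torch.sigmoid(logits)
    # p_t is p for positive targets and (1 - p) for negative targets.
    p_t = torch.where(targets == 1.0, p, 1.0 - p)
    naive = (1.0 - p_t) ** gamma
    stable = torch.exp(gamma * targets * -logits - gamma * torch.log1p(torch.exp(-logits)))
    assert torch.allclose(naive, stable, atol=1e-6)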


def huber_loss(
        input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
    """Huber loss: quadratic for errors with magnitude below `delta`, linear beyond it.

    Optionally applies element-wise `weights` before reduction.
    """
    err = input - target
    abs_err = err.abs()
    quadratic = torch.clamp(abs_err, max=delta)
    linear = abs_err - quadratic
    loss = 0.5 * quadratic.pow(2) + delta * linear
    if weights is not None:
        loss *= weights
    return loss.mean() if size_average else loss.sum()


def smooth_l1_loss(
        input, target, beta: float = 1. / 9, weights: Optional[torch.Tensor] = None, size_average: bool = True):
    """Very similar to the smooth_l1_loss from PyTorch, but with an extra beta parameter."""
    if beta < 1e-5:
        # If beta == 0, then torch.where will result in nan gradients when
        # the chain rule is applied due to PyTorch implementation details
        # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
        # zeros, rather than "no gradient"). To avoid this issue, we define
        # small values of beta to be exactly L1 loss.
        loss = torch.abs(input - target)
    else:
        err = torch.abs(input - target)
        loss = torch.where(err < beta, 0.5 * err.pow(2) / beta, err - 0.5 * beta)
    if weights is not None:
        loss *= weights
    return loss.mean() if size_average else loss.sum()
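

# Illustrative sketch of the two regimes of smooth_l1_loss above: for |err| < beta
# the loss is quadratic, 0.5 * err**2 / beta, and beyond beta it is linear,
# |err| - 0.5 * beta. The values below are hand-picked for this check only.
def _check_smooth_l1_regimes(beta: float = 1. / 9):
    zero = torch.zeros(1)
    small = torch.tensor([beta / 10.])   # well inside the quadratic region
    large = torch.tensor([10. * beta])   # well inside the linear region
    assert abs(smooth_l1_loss(small, zero, beta=beta).item() - 0.5 * (beta / 10.) ** 2 / beta) < 1e-6
    assert abs(smooth_l1_loss(large, zero, beta=beta).item() - (10. * beta - 0.5 * beta)) < 1e-6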


def _classification_loss(cls_outputs, cls_targets, num_positives, alpha: float = 0.25, gamma: float = 2.0):
    """Computes classification loss."""
    normalizer = num_positives
    classification_loss = focal_loss(cls_outputs, cls_targets, alpha, gamma, normalizer)
    return classification_loss


def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
    """Computes box regression loss."""
    # delta is typically around the mean value of the regression targets.
    # For instance, the regression targets of a 512x512 input with 6 anchors on
    # the P3-P7 pyramid are about [0.1, 0.1, 0.2, 0.2].
    normalizer = num_positives * 4.0
    mask = box_targets != 0.0
    box_loss = huber_loss(box_outputs, box_targets, weights=mask, delta=delta, size_average=False)
    box_loss /= normalizer
    return box_loss
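

# Illustrative sketch: anchors whose regression targets are all zero (unmatched or
# padded anchors) are masked out in _box_loss above, so a level with no positive
# anchors contributes exactly zero box loss. Shapes here are arbitrary.
def _check_box_loss_mask():
    box_outputs = torch.randn(2, 8, 8, 36)
    box_targets = torch.zeros(2, 8, 8, 36)  # no positive anchors at this level
    num_positives = torch.tensor(3.0)
    assert _box_loss(box_outputs, box_targets, num_positives).item() == 0.0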


class DetectionLoss(nn.Module):

    def __init__(self, config):
        super(DetectionLoss, self).__init__()
        self.config = config
        self.num_classes = config.num_classes
        self.alpha = config.alpha
        self.gamma = config.gamma
        self.delta = config.delta
        self.box_loss_weight = config.box_loss_weight

    def forward(
            self, cls_outputs: List[torch.Tensor], box_outputs: List[torch.Tensor],
            cls_targets: List[torch.Tensor], box_targets: List[torch.Tensor], num_positives: torch.Tensor):
        """Computes total detection loss.

        Computes total detection loss including box and class loss from all levels.

        Args:
            cls_outputs: a List with values representing logits in
                [batch_size, height, width, num_anchors] at each feature level (index).
            box_outputs: a List with values representing box regression targets in
                [batch_size, height, width, num_anchors * 4] at each feature level (index).
            cls_targets: groundtruth class targets.
            box_targets: groundtruth box targets.
            num_positives: number of positive groundtruth anchors.

        Returns:
            total_loss: a float tensor representing total loss reduced from class and box losses of all levels.
            cls_loss: a float tensor representing total class loss.
            box_loss: a float tensor representing total box regression loss.
        """
        # Sum all positives in a batch for normalization and avoid a zero
        # num_positives_sum, which would lead to inf loss during training.
        num_positives_sum = num_positives.sum() + 1.0
        levels = len(cls_outputs)

        cls_losses = []
        box_losses = []
        for l in range(levels):
            cls_targets_at_level = cls_targets[l]
            box_targets_at_level = box_targets[l]

            # One-hot encoding for classification labels.
            # NOTE: PyTorch one-hot does not handle negative entries (no hot) like TensorFlow, so mask them out.
            cls_targets_non_neg = cls_targets_at_level >= 0
            cls_targets_at_level_oh = F.one_hot(cls_targets_at_level * cls_targets_non_neg, self.num_classes)
            cls_targets_at_level_oh = torch.where(
                cls_targets_non_neg.unsqueeze(-1), cls_targets_at_level_oh, torch.zeros_like(cls_targets_at_level_oh))

            bs, height, width, _, _ = cls_targets_at_level_oh.shape
            cls_targets_at_level_oh = cls_targets_at_level_oh.view(bs, height, width, -1)
            cls_loss = _classification_loss(
                cls_outputs[l].permute(0, 2, 3, 1),
                cls_targets_at_level_oh,
                num_positives_sum,
                alpha=self.alpha, gamma=self.gamma)
            cls_loss = cls_loss.view(bs, height, width, -1, self.num_classes)
            cls_loss *= (cls_targets_at_level != -2).unsqueeze(-1).float()
            cls_losses.append(cls_loss.sum())

            box_losses.append(_box_loss(
                box_outputs[l].permute(0, 2, 3, 1),
                box_targets_at_level,
                num_positives_sum,
                delta=self.delta))

        # Sum per-level losses to total loss.
        cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)
        box_loss = torch.sum(torch.stack(box_losses, dim=-1), dim=-1)
        total_loss = cls_loss + self.box_loss_weight * box_loss
        return total_loss, cls_loss, box_loss
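

if __name__ == '__main__':
    # Minimal smoke test, a hedged sketch: the SimpleNamespace config and the example
    # values (num_classes, alpha, gamma, delta, box_loss_weight, shapes) are assumptions
    # for illustration only; the real project passes its own config object.
    from types import SimpleNamespace

    config = SimpleNamespace(
        num_classes=90, alpha=0.25, gamma=1.5, delta=0.1, box_loss_weight=50.0)
    loss_fn = DetectionLoss(config)

    bs, num_anchors = 2, 9
    feat_sizes = [(64, 64), (32, 32)]  # two pyramid levels for the smoke test

    # Network outputs are NCHW; forward() permutes them to [bs, h, w, C] internally.
    cls_outputs = [torch.randn(bs, num_anchors * config.num_classes, h, w) for h, w in feat_sizes]
    box_outputs = [torch.randn(bs, num_anchors * 4, h, w) for h, w in feat_sizes]
    # Integer class targets per anchor: -2 is ignored by the loss mask above, -1 gets an
    # all-zero (background) one-hot, and values >= 0 are class ids.
    cls_targets = [torch.randint(-2, config.num_classes, (bs, h, w, num_anchors)) for h, w in feat_sizes]
    box_targets = [torch.randn(bs, h, w, num_anchors * 4) for h, w in feat_sizes]
    num_positives = torch.tensor([10.0, 12.0])

    total_loss, cls_loss, box_loss = loss_fn(cls_outputs, box_outputs, cls_targets, box_targets, num_positives)
    print(total_loss.item(), cls_loss.item(), box_loss.item())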