-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfocal_loss.py
94 lines (76 loc) · 3.01 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
        - output: a 0-d tensor for reduction 'mean' / 'sum'. For 'none', a
          1-D tensor with one entry per non-ignored target, in flattened
          (batch_size * d1 * ... * dK) order (ignored positions are dropped).
    """

    def __init__(self,
                 alpha: Tensor = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class, passed to
                ``nn.NLLLoss`` as its ``weight`` (so a 1-D tensor of size C).
                Defaults to None (all classes weighted equally).
            gamma (float, optional): Focusing constant, as described in the
                paper. ``gamma=0`` reduces to (weighted) cross entropy.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.

        Raises:
            ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
        """
        super().__init__()
        # Validate before building sub-modules so a bad config fails fast.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')
        self.reduction = reduction
        self.gamma = gamma
        self.ignore_index = ignore_index
        # reduction='none' keeps per-element -alpha[y] * log(pt) so the
        # focal term can be applied element-wise before reducing.
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Compute the focal loss of logits ``x`` against labels ``y``.

        Args:
            x (Tensor): unnormalized class scores, (N, C) or (N, C, d1..dK).
            y (Tensor): class labels, (N,) or (N, d1..dK).

        Returns:
            Tensor: the reduced loss (0-d) for 'mean' / 'sum', or the
            per-element loss of the non-ignored targets for 'none'. If every
            target equals ``ignore_index``, the result is a zero (or, for
            'none', an empty) tensor on ``x``'s device/dtype that is still
            connected to the autograd graph of ``x``.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous label tensors also work.
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        x = x[unignored_mask]

        if y.numel() == 0:
            # Every target is ignored. Summing the empty selection of x yields
            # zeros on x's device/dtype that remain in the autograd graph, so
            # callers can still call .backward() / .item() on the result.
            empty = x.sum(dim=-1)  # shape (0,)
            return empty if self.reduction == 'none' else empty.sum()

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is applied by nn.NLLLoss through its `weight` argument)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row; gather uses y itself as the
        # index, so it works on whatever device log_p and y live on
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            # Plain mean over non-ignored elements; unlike nn.NLLLoss's own
            # 'mean', this is NOT normalized by the sum of alpha weights.
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
def focal_loss(alpha=None, gamma=0., reduction='mean', ignore_index=-100,
               device='cpu', dtype=torch.float32):
    """Build a ``FocalLoss`` module, accepting class weights in any form.

    Args:
        alpha: per-class weights. ``None`` and ``torch.Tensor`` values are
            forwarded as-is. Anything else (e.g. a list of floats) is
            converted with ``torch.tensor(alpha, device=device, dtype=dtype)``.
        gamma (float): focusing constant forwarded to ``FocalLoss``.
        reduction (str): 'mean', 'sum' or 'none', forwarded to ``FocalLoss``.
        ignore_index (int): label to ignore, forwarded to ``FocalLoss``.
        device: device used only when ``alpha`` must be converted.
        dtype: dtype used only when ``alpha`` must be converted.

    Returns:
        FocalLoss: the configured loss module.
    """
    # A tensor (or None) passes straight through; device/dtype are ignored then.
    must_convert = alpha is not None and not isinstance(alpha, torch.Tensor)
    weights = torch.tensor(alpha, device=device, dtype=dtype) if must_convert else alpha
    return FocalLoss(alpha=weights, gamma=gamma, reduction=reduction,
                     ignore_index=ignore_index)