add AsymmetricUnifiedFocalLoss for multi-class segmentation
Showing 2 changed files with 164 additions and 0 deletions.
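For reference, the new class computes the asymmetric focal (cross-entropy) term of the Unified Focal Loss: the background channel (index 0) is down-weighted by $1-\delta$ and suppressed by a focal factor, while the foreground channels keep their full cross entropy weighted by $\delta$. Per voxel, for a one-hot target $y$, predicted probabilities $p$ and $C$ classes, the code below evaluates

$$\mathcal{L}_{aF} = (1-\delta)\,(1-p_{0})^{\gamma}\bigl(-y_{0}\log p_{0}\bigr) + \delta \sum_{c=1}^{C-1}\bigl(-y_{c}\log p_{c}\bigr),$$

followed by the configured reduction (the mean, by default) over classes, batch and spatial dimensions.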
@@ -0,0 +1,110 @@
from __future__ import annotations

import torch
from monai.networks import one_hot
from monai.utils import LossReduction
from torch.nn.modules.loss import _Loss


class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.

    Implementation of the Asymmetric Unified Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle
      Class Imbalanced Medical Image Segmentation", Michael Yeung et al.,
      Computerized Medical Imaging and Graphics, 2022
    - https://github.com/mlyg/unified-focal-loss/issues/8
    """

    def __init__(
        self,
        to_onehot_y: bool = False,
        num_classes: int = 2,
        gamma: float = 0.75,
        delta: float = 0.7,
        reduction: LossReduction | str = LossReduction.MEAN,
    ):
        """
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            num_classes: number of classes. Defaults to 2.
            gamma: exponent of the focal modulation applied to the background class. Defaults to 0.75.
            delta: weight of the foreground classes; the background is weighted with ``1 - delta``. Defaults to 0.7.
            reduction: {``"none"``, ``"mean"``, ``"sum"``} reduction applied to the output. Defaults to ``"mean"``.

        Example:
            >>> import torch
            >>> from segmantic.seg.losses import AsymmetricUnifiedFocalLoss
            >>> pred = torch.ones((1, 1, 32, 32), dtype=torch.float32)
            >>> grnd = torch.ones((1, 1, 32, 32), dtype=torch.int64)
            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.num_classes = num_classes
        self.gamma = gamma
        self.delta = delta
        self.epsilon = 1e-7  # numerical floor used when clipping probabilities

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: the shape should be BNH[WD], where N is the number of classes.
                The values should be probabilities in [0, 1] (e.g. softmax outputs);
                no sigmoid or softmax is applied inside this function.
            y_true: the shape should match ``y_pred``. When the channel dimension is 1,
                both tensors are interpreted as label maps and are converted to the
                one-hot format with ``num_classes`` channels.

        Raises:
            ValueError: when ``y_pred`` and ``y_true`` have different shapes.
            ValueError: when ``y_pred`` is not 4D (BNHW) or 5D (BNHWD).
            ValueError: when the maximum label in ``y_true`` does not equal ``num_classes - 1``.
            ValueError: when the number of channels does not match ``num_classes``.
        """
        if y_pred.shape != y_true.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
            )

        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
            raise ValueError(f"input must be 4D or 5D, but got shape {y_pred.shape}")

        if y_pred.shape[1] == 1:
            # inputs are label maps: validate the labels before expanding to one-hot
            # (after one-hot conversion the maximum value is always 1); note that the
            # conversion currently happens independently of `to_onehot_y`
            if torch.max(y_true) != self.num_classes - 1:
                raise ValueError(
                    f"the maximum label should be {self.num_classes - 1}, but got {int(torch.max(y_true))}"
                )
            y_pred = one_hot(y_pred, num_classes=self.num_classes)
            y_true = one_hot(y_true, num_classes=self.num_classes)

        if y_pred.shape[1] != self.num_classes:
            raise ValueError(
                f"the number of channels ({y_pred.shape[1]}) does not match num_classes ({self.num_classes})"
            )

        # TODO: use pytorch CrossEntropyLoss?
        # https://github.com/Project-MONAI/MONAI/blob/2d463a7d19166cff6a83a313f339228bc812912d/monai/losses/dice.py#L741
        y_pred = torch.clip(y_pred, self.epsilon, 1.0 - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred)

        # calculate losses separately for each class, only enhancing the foreground:
        # the background term is down-weighted by (1 - delta) and suppressed by the
        # focal factor (1 - p_0)^gamma
        back_ce = torch.pow(1 - y_pred[:, 0, ...], self.gamma) * cross_entropy[:, 0, ...]
        back_ce = (1 - self.delta) * back_ce

        # the foreground classes keep their full cross entropy, weighted by delta
        losses = [back_ce]
        for i in range(1, self.num_classes):
            losses.append(self.delta * cross_entropy[:, i, ...])

        loss = torch.stack(losses)  # shape: (num_classes, B, H, W[, D])
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)  # sum over class, batch and spatial dims
        if self.reduction == LossReduction.NONE.value:
            return loss  # per-class, per-voxel losses of shape (num_classes, B, H, W[, D])
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
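As a usage illustration, here is a minimal sketch of driving the loss with three classes. It is not part of the commit: the module path is taken from the test file below, and the shapes, class count and tensor values are made up. The loss expects probabilities, so the logits are passed through a softmax, and a multi-channel target must already be one-hot:

import torch

from segmantic.seg.losses import AsymmetricUnifiedFocalLoss

# batch of 2, three classes, 32x32 images (illustrative values only)
logits = torch.randn(2, 3, 32, 32, requires_grad=True)
probs = torch.softmax(logits, dim=1)  # the loss expects probabilities, not logits

labels = torch.randint(0, 3, (2, 1, 32, 32))               # integer label map
target = torch.zeros_like(probs).scatter_(1, labels, 1.0)  # one-hot, same shape as probs

loss_fn = AsymmetricUnifiedFocalLoss(num_classes=3)
loss = loss_fn(probs, target)
loss.backward()  # gradients flow back to the logits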
@@ -0,0 +1,54 @@
from __future__ import annotations

import numpy as np
import pytest
import torch

from segmantic.seg.losses import AsymmetricUnifiedFocalLoss

TEST_CASES = [
    (  # shape: (2, 1, 2, 2) label maps, perfect prediction
        {
            "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
            "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
        },
        0.0,
    ),
    (  # shape: (2, 1, 2, 2, 2) volumetric label maps, perfect prediction
        {
            "y_pred": torch.ones((2, 1, 2, 2, 2)),
            "y_true": torch.ones((2, 1, 2, 2, 2)),
        },
        0.0,
    ),
]


@pytest.mark.parametrize("input_data,expected_val", TEST_CASES)
def test_result(input_data, expected_val):
    loss = AsymmetricUnifiedFocalLoss()
    result = loss(**input_data)
    np.testing.assert_allclose(
        result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4
    )


def test_ill_shape():
    loss = AsymmetricUnifiedFocalLoss()
    with pytest.raises(ValueError):
        loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2)))


def test_with_cuda():
    loss = AsymmetricUnifiedFocalLoss()
    i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
    j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
    if torch.cuda.is_available():
        # run on the GPU when available; otherwise the same check runs on CPU
        i = i.cuda()
        j = j.cuda()
    output = loss(i, j)
    np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
    test_with_cuda()
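One property the suite does not yet exercise is differentiability. A hypothetical extra test, not part of the commit, could check that gradients flow back through a multi-channel probability input:

def test_grad_flows():
    # hypothetical extra check: the loss should backpropagate to the prediction
    loss = AsymmetricUnifiedFocalLoss(num_classes=2)
    pred = torch.full((1, 2, 2, 2), 0.5, requires_grad=True)  # uniform probabilities
    target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]])  # one-hot
    output = loss(pred, target)
    output.backward()
    assert pred.grad is not None
    assert bool(torch.isfinite(pred.grad).all())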