Add CrossEntropy and DiceCE loss wrappers with support for class weights
nkaenzig committed Oct 14, 2024
1 parent 1989fd3 commit b71df67
Showing 8 changed files with 117 additions and 22 deletions.
5 changes: 2 additions & 3 deletions configs/vision/radiology/offline/segmentation/lits.yaml
@@ -57,10 +57,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.01, 0.1, 1.5]
     optimizer:
       class_path: torch.optim.AdamW
       init_args:
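For context, jsonargparse resolves each `class_path`/`init_args` pair in these configs into a class instance at load time. A rough Python equivalent of the new criterion block above (a sketch, using the wrapper added later in this commit):

    from eva.core.losses import CrossEntropyLoss

    # Equivalent of the YAML criterion block: the per-class weights are
    # passed as a plain list and converted to a tensor inside the wrapper.
    criterion = CrossEntropyLoss(weight=[0.01, 0.1, 1.5])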
(changed file name not captured in this extract)

@@ -57,10 +57,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     optimizer:
       class_path: torch.optim.AdamW
       init_args:
5 changes: 2 additions & 3 deletions configs/vision/radiology/online/segmentation/lits.yaml
@@ -49,10 +49,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
      init_args:
-        softmax: true
-        batch: true
+        weight: [0.01, 0.1, 1.5]
     lr_multiplier_encoder: 0.0
     optimizer:
       class_path: torch.optim.AdamW
(changed file name not captured in this extract)

@@ -49,10 +49,9 @@ model:
         in_features: ${oc.env:IN_FEATURES, 384}
         num_classes: &NUM_CLASSES 3
     criterion:
-      class_path: eva.vision.losses.DiceLoss
+      class_path: eva.core.losses.CrossEntropyLoss
       init_args:
-        softmax: true
-        batch: true
+        weight: [0.05, 0.1, 1.5]
     lr_multiplier_encoder: 0.0
     optimizer:
       class_path: torch.optim.AdamW
5 changes: 5 additions & 0 deletions src/eva/core/losses/__init__.py
@@ -0,0 +1,5 @@
+"""Loss functions API."""
+
+from eva.core.losses.cross_entropy import CrossEntropyLoss
+
+__all__ = ["CrossEntropyLoss"]
27 changes: 27 additions & 0 deletions src/eva/core/losses/cross_entropy.py
@@ -0,0 +1,27 @@
+"""Cross-entropy based loss function."""
+
+from typing import Sequence
+
+import torch
+from torch import nn
+
+
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+    """A wrapper around torch.nn.CrossEntropyLoss that accepts weights in list format.
+
+    Needed for .yaml file loading & class instantiation with jsonargparse.
+    """
+
+    def __init__(
+        self, *args, weight: Sequence[float] | torch.Tensor | None = None, **kwargs
+    ) -> None:
+        """Initialize the loss function.
+
+        Args:
+            args: Positional arguments from the base class.
+            weight: A list of weights to assign to each class.
+            kwargs: Keyword arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+        super().__init__(*args, **kwargs, weight=weight)
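A minimal usage sketch of the new wrapper (the tensor shapes are illustrative, not taken from this commit):

    import torch

    from eva.core.losses import CrossEntropyLoss

    # A weight list, e.g. parsed from a YAML config, is converted to a
    # tensor before being forwarded to nn.CrossEntropyLoss.
    criterion = CrossEntropyLoss(weight=[0.01, 0.1, 1.5])

    logits = torch.randn(4, 3, 64, 64)          # (batch, classes, H, W)
    targets = torch.randint(0, 3, (4, 64, 64))  # semantic labels of shape BHW
    loss = criterion(logits, targets)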
4 changes: 2 additions & 2 deletions src/eva/vision/losses/__init__.py
@@ -1,5 +1,5 @@
 """Loss functions API."""

-from eva.vision.losses.dice import DiceLoss
+from eva.vision.losses.dice import DiceCELoss, DiceLoss

-__all__ = ["DiceLoss"]
+__all__ = ["DiceLoss", "DiceCELoss"]
83 changes: 75 additions & 8 deletions src/eva/vision/losses/dice.py
@@ -1,4 +1,6 @@
-"""Dice loss."""
+"""Dice-based loss functions."""
+
+from typing import Sequence, Tuple

 import torch
 from monai import losses
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss):  # type: ignore
     Extends the implementation from MONAI
     - to support semantic target labels (meaning targets of shape BHW)
     - to support `ignore_index` functionality
+    - to accept a weight argument in list format
     """

-    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
-        """Initialize the DiceLoss with support for ignore_index.
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceLoss.

         Args:
             args: Positional arguments from the base class.
             ignore_index: Specifies a target value that is ignored and
                 does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
             kwargs: Keyword arguments from the base class.
         """
-        super().__init__(*args, **kwargs)
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
         self.ignore_index = ignore_index

     @override
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
-        if self.ignore_index is not None:
-            mask = targets != self.ignore_index
-            targets = targets * mask
-            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
-
-        if targets.ndim == 3:
-            targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])

         return super().forward(inputs, targets)
+
+
+class DiceCELoss(losses.dice.DiceCELoss):
+    """Combination of Dice and Cross Entropy Loss.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    - to accept a weight argument in list format
+    """
+
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceCELoss.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
+            kwargs: Keyword arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
+
+
+def _apply_ignore_index(
+    inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if ignore_index is not None:
+        mask = targets != ignore_index
+        targets = targets * mask
+        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+    return inputs, targets
+
+
+def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return one_hot(tensor[:, None, ...], num_classes=num_classes)
+    return tensor
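And a minimal sketch of the extended Dice losses in use (the ignore value 255 and all shapes are illustrative assumptions, not project defaults):

    import torch

    from eva.vision.losses import DiceCELoss, DiceLoss

    # Both wrappers accept semantic targets of shape BHW, class weights in
    # list format, and an optional ignore_index whose pixels are masked out
    # of the loss via _apply_ignore_index.
    dice = DiceLoss(softmax=True, batch=True, weight=[0.01, 0.1, 1.5], ignore_index=255)
    dice_ce = DiceCELoss(softmax=True, weight=[0.01, 0.1, 1.5], ignore_index=255)

    logits = torch.randn(2, 3, 32, 32)          # (batch, classes, H, W)
    targets = torch.randint(0, 3, (2, 32, 32))  # semantic labels
    targets[0, :4, :4] = 255                    # pixels to exclude from the loss

    print(dice(logits, targets).item(), dice_ce(logits, targets).item())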
