Skip to content

Commit

Permalink
Add CrossEntropyLoss and DiceCELoss wrappers with support for cla…
Browse files Browse the repository at this point in the history
…ss weights (#686)
  • Loading branch information
nkaenzig authored Oct 14, 2024
1 parent e192eb9 commit 209b694
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 26 deletions.
9 changes: 5 additions & 4 deletions configs/vision/radiology/offline/segmentation/lits.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trainer:
refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
- class_path: eva.vision.callbacks.SemanticSegmentationLogger
init_args:
log_every_n_steps: 1000
log_every_n_epochs: 1
log_images: false
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
Expand Down Expand Up @@ -57,10 +57,9 @@ model:
in_features: ${oc.env:IN_FEATURES, 384}
num_classes: &NUM_CLASSES 3
criterion:
class_path: eva.vision.losses.DiceLoss
class_path: eva.core.losses.CrossEntropyLoss
init_args:
softmax: true
batch: true
weight: [0.05, 0.1, 1.5]
optimizer:
class_path: torch.optim.AdamW
init_args:
Expand Down Expand Up @@ -119,6 +118,7 @@ data:
class_path: eva.vision.data.transforms.common.ResizeAndClamp
init_args:
size: ${oc.env:RESIZE_DIM, 224}
clamp_range: [-1008, 822]
mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
- class_path: eva.vision.datasets.LiTS
Expand All @@ -137,6 +137,7 @@ data:
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
shuffle: true
test:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
Expand Down
10 changes: 5 additions & 5 deletions configs/vision/radiology/offline/segmentation/lits_balanced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
callbacks:
- class_path: eva.callbacks.ConfigurationLogger
Expand All @@ -29,7 +29,7 @@ trainer:
mode: *MONITOR_METRIC_MODE
- class_path: eva.callbacks.SegmentationEmbeddingsWriter
init_args:
output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits
output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced
dataloader_idx_map:
0: train
1: val
Expand Down Expand Up @@ -57,10 +57,9 @@ model:
in_features: ${oc.env:IN_FEATURES, 384}
num_classes: &NUM_CLASSES 3
criterion:
class_path: eva.vision.losses.DiceLoss
class_path: eva.core.losses.CrossEntropyLoss
init_args:
softmax: true
batch: true
weight: [0.05, 0.1, 1.5]
optimizer:
class_path: torch.optim.AdamW
init_args:
Expand Down Expand Up @@ -119,6 +118,7 @@ data:
class_path: eva.vision.data.transforms.common.ResizeAndClamp
init_args:
size: ${oc.env:RESIZE_DIM, 224}
clamp_range: [-1008, 822]
mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
- class_path: eva.vision.datasets.LiTSBalanced
Expand Down
6 changes: 3 additions & 3 deletions configs/vision/radiology/online/segmentation/lits.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ model:
in_features: ${oc.env:IN_FEATURES, 384}
num_classes: &NUM_CLASSES 3
criterion:
class_path: eva.vision.losses.DiceLoss
class_path: eva.core.losses.CrossEntropyLoss
init_args:
softmax: true
batch: true
weight: [0.05, 0.1, 1.5]
lr_multiplier_encoder: 0.0
optimizer:
class_path: torch.optim.AdamW
Expand Down Expand Up @@ -96,6 +95,7 @@ data:
class_path: eva.vision.data.transforms.common.ResizeAndClamp
init_args:
size: ${oc.env:RESIZE_DIM, 224}
clamp_range: [-1008, 822]
mean: *NORMALIZE_MEAN
std: *NORMALIZE_STD
val:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
log_every_n_steps: 6
callbacks:
Expand Down Expand Up @@ -49,10 +49,9 @@ model:
in_features: ${oc.env:IN_FEATURES, 384}
num_classes: &NUM_CLASSES 3
criterion:
class_path: eva.vision.losses.DiceLoss
class_path: eva.core.losses.CrossEntropyLoss
init_args:
softmax: true
batch: true
weight: [0.05, 0.1, 1.5]
lr_multiplier_encoder: 0.0
optimizer:
class_path: torch.optim.AdamW
Expand Down Expand Up @@ -96,6 +95,7 @@ data:
class_path: eva.vision.data.transforms.common.ResizeAndClamp
init_args:
size: ${oc.env:RESIZE_DIM, 224}
clamp_range: [-1008, 822]
mean: *NORMALIZE_MEAN
std: *NORMALIZE_STD
val:
Expand Down
5 changes: 5 additions & 0 deletions src/eva/core/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Loss functions API."""

from eva.core.losses.cross_entropy import CrossEntropyLoss

__all__ = ["CrossEntropyLoss"]
27 changes: 27 additions & 0 deletions src/eva/core/losses/cross_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Cross-entropy based loss function."""

from typing import Sequence

import torch
from torch import nn


class CrossEntropyLoss(nn.CrossEntropyLoss):
"""A wrapper around torch.nn.CrossEntropyLoss that accepts weights in list format.
Needed for .yaml file loading & class instantiation with jsonarparse.
"""

def __init__(
self, *args, weight: Sequence[float] | torch.Tensor | None = None, **kwargs
) -> None:
"""Initialize the loss function.
Args:
args: Positional arguments from the base class.
weight: A list of weights to assign to each class.
kwargs: Key-word arguments from the base class.
"""
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight)
super().__init__(*args, **kwargs, weight=weight)
4 changes: 2 additions & 2 deletions src/eva/vision/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Loss functions API."""

from eva.vision.losses.dice import DiceLoss
from eva.vision.losses.dice import DiceCELoss, DiceLoss

__all__ = ["DiceLoss"]
__all__ = ["DiceLoss", "DiceCELoss"]
83 changes: 75 additions & 8 deletions src/eva/vision/losses/dice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Dice loss."""
"""Dice based loss functions."""

from typing import Sequence, Tuple

import torch
from monai import losses
Expand All @@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss): # type: ignore
Extends the implementation from MONAI
- to support semantic target labels (meaning targets of shape BHW)
- to support `ignore_index` functionality
- accept weight argument in list format
"""

def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
"""Initialize the DiceLoss with support for ignore_index.
def __init__(
self,
*args,
ignore_index: int | None = None,
weight: Sequence[float] | torch.Tensor | None = None,
**kwargs,
) -> None:
"""Initialize the DiceLoss.
Args:
args: Positional arguments from the base class.
ignore_index: Specifies a target value that is ignored and
does not contribute to the input gradient.
weight: A list of weights to assign to each class.
kwargs: Key-word arguments from the base class.
"""
super().__init__(*args, **kwargs)
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight)

super().__init__(*args, **kwargs, weight=weight)

self.ignore_index = ignore_index

@override
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa
if self.ignore_index is not None:
mask = targets != self.ignore_index
targets = targets * mask
inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
targets = _to_one_hot(targets, num_classes=inputs.shape[1])

if targets.ndim == 3:
targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])

return super().forward(inputs, targets)


class DiceCELoss(losses.dice.DiceCELoss):
"""Combination of Dice and Cross Entropy Loss.
Extends the implementation from MONAI
- to support semantic target labels (meaning targets of shape BHW)
- to support `ignore_index` functionality
- accept weight argument in list format
"""

def __init__(
self,
*args,
ignore_index: int | None = None,
weight: Sequence[float] | torch.Tensor | None = None,
**kwargs,
) -> None:
"""Initialize the DiceCELoss.
Args:
args: Positional arguments from the base class.
ignore_index: Specifies a target value that is ignored and
does not contribute to the input gradient.
weight: A list of weights to assign to each class.
kwargs: Key-word arguments from the base class.
"""
if weight is not None and not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight)

super().__init__(*args, **kwargs, weight=weight)

self.ignore_index = ignore_index

@override
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # noqa
inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
targets = _to_one_hot(targets, num_classes=inputs.shape[1])

return super().forward(inputs, targets)


def _apply_ignore_index(
inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
) -> Tuple[torch.Tensor, torch.Tensor]:
if ignore_index is not None:
mask = targets != ignore_index
targets = targets * mask
inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
return inputs, targets


def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
if tensor.ndim == 3:
return one_hot(tensor[:, None, ...], num_classes=num_classes)
return tensor

0 comments on commit 209b694

Please sign in to comment.