Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct behavior of ignore_index for JaccardLoss #1898

Merged
merged 11 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/conf/chesapeake_cvpr_5.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
loss: "jaccard"
model: "unet"
backbone: "resnet50"
in_channels: 4
num_classes: 5
num_filters: 1
ignore_index: null
ignore_index: 0
data:
class_path: ChesapeakeCVPRDataModule
init_args:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/chesapeake_cvpr_7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ model:
in_channels: 4
num_classes: 7
num_filters: 1
ignore_index: null
ignore_index: 0
data:
class_path: ChesapeakeCVPRDataModule
init_args:
Expand Down
5 changes: 0 additions & 5 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,6 @@ def test_invalid_loss(self) -> None:
with pytest.raises(ValueError, match=match):
SemanticSegmentationTask(loss="invalid_loss")

def test_ignoreindex_with_jaccard(self) -> None:
match = "ignore_index has no effect on training when loss='jaccard'"
with pytest.warns(UserWarning, match=match):
SemanticSegmentationTask(loss="jaccard", ignore_index=0)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(SEN12MSDataModule, "plot", plot)
datamodule = SEN12MSDataModule(
Expand Down
23 changes: 10 additions & 13 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""Trainers for semantic segmentation."""

import os
import warnings
from typing import Any, Optional, Union

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -70,9 +69,6 @@ class and used with 'ce' loss.
freeze_decoder: Freeze the decoder network to linear probe
the segmentation head.

Warns:
UserWarning: When loss='jaccard' and ignore_index is specified.

.. versionchanged:: 0.3
*ignore_zeros* was renamed to *ignore_index*.

Expand All @@ -87,13 +83,10 @@ class and used with 'ce' loss.
The *weights* parameter now supports WeightEnums and checkpoint paths.
*learning_rate* and *learning_rate_schedule_patience* were renamed to
*lr* and *patience*.
"""
if ignore_index is not None and loss == "jaccard":
warnings.warn(
"ignore_index has no effect on training when loss='jaccard'",
UserWarning,
)

.. versionchanged:: 0.6
The *ignore_index* parameter now works for jaccard loss.
"""
self.weights = weights
super().__init__(ignore="weights")

Expand All @@ -111,9 +104,13 @@ def configure_losses(self) -> None:
ignore_index=ignore_value, weight=self.hparams["class_weights"]
)
elif loss == "jaccard":
self.criterion = smp.losses.JaccardLoss(
mode="multiclass", classes=self.hparams["num_classes"]
)
# JaccardLoss requires a list of classes to use instead of a class
# index to ignore.
classes = [
i for i in range(self.hparams["num_classes"]) if i != ignore_index
]

self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=classes)
elif loss == "focal":
self.criterion = smp.losses.FocalLoss(
"multiclass", ignore_index=ignore_index, normalized=True
Expand Down
Loading