diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 168a982adc8..204f28da775 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -110,9 +110,7 @@ def configure_losses(self) -> None: i for i in range(self.hparams["num_classes"]) if i != ignore_index ] - self.criterion = smp.losses.JaccardLoss( - mode="multiclass", classes=classes - ) + self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=classes) elif loss == "focal": self.criterion = smp.losses.FocalLoss( "multiclass", ignore_index=ignore_index, normalized=True