diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d8b207d5d2d..7f4316cc021 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -62,7 +62,7 @@ class TestSemanticSegmentationTask: 'deepglobelandcover', 'etci2021', 'gid15', - 'inria', + 'inria' 'l7irish', 'l8biome', 'landcoverai', diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index c3d30e80259..45b850cff0f 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 2b480352d72..0243010b179 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -122,8 +122,10 @@ def __init__( self.layers = ['naip-new', 'lc'] self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 3112290dea9..920cc1644d2 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -49,7 +49,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index cfa7452a6d7..58855d26fae 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -38,7 +38,10 @@ def __init__( K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index c5c86d75756..62a72756097 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -72,6 +72,8 @@ def __init__( self.aug: Transform = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 06985a73bb0..a7e48226a1a 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -51,13 +51,19 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 7bd4d0ae165..de369fe67ed 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -49,18 +49,26 @@ def __init__( K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + self.predict_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index aca9693ea9a..0f3ec561990 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index e94db30d392..2c2ab4c2152 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -54,10 +54,13 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index d0b4cd62019..0f60c11bd35 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -38,7 +38,12 @@ def __init__( K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 3f0343010a9..8a5bc9e8012 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -46,14 +46,20 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.val_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) self.test_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.test_aug.keepdim = True # type: ignore[attr-defined] + class LEVIRCDPlusDataModule(NonGeoDataModule): """LightningDataModule implementation for the LEVIR-CD+ dataset. @@ -92,14 +98,20 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) self.val_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) self.test_aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.test_aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 7e56ccbb142..03cc8d35b9d 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -57,8 +57,10 @@ def __init__( ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 414850426a0..298cd65b5b5 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -88,7 +88,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 452c5c47844..0526d5b6026 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -51,7 +51,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 1a3e19a5122..2316c23b4ae 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -39,4 +39,7 @@ def __init__( K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), data_keys=['image'], + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index d908bb840c1..0a046a5efbb 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -44,4 +44,7 @@ def __init__( K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index f18447acad1..b52f5506af5 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -53,7 +53,10 @@ def __init__( K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index d0fad7caaac..b84dd9e755c 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -68,15 +68,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 4e0893e4f8d..9eeea37d5cd 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -13,7 +13,6 @@ from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -64,21 +63,26 @@ def __init__( **self.eurocrops_kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 6d306058dea..4897c267c16 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -68,15 +68,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index a7723a2ca5f..b3963b40c6b 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -67,15 +67,20 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = K.AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=None + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/southafricacroptype.py b/torchgeo/datamodules/southafricacroptype.py index 3f44bb61471..45c0534815f 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -12,7 +12,6 @@ from ..datasets import SouthAfricaCropType, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,21 +48,26 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index ac37ae2069d..a7ce3ea657e 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -54,13 +54,19 @@ def __init__( K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), data_keys=None, + keepdim=True, ) self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.aug.keepdim = True # type: ignore[attr-defined] + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index 2b4aba14422..46b57eba0ce 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -45,6 +45,7 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, @@ -53,9 +54,16 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), data_keys=None, + keepdim=True, ) self.test_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), data_keys=None, + keepdim=True, ) + + # https://github.com/kornia/kornia/issues/2848 + self.train_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] + self.val_aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 34537cda365..abc30b29c37 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -34,4 +34,7 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 35eec7821ae..49414358cd1 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -51,7 +51,10 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), data_keys=None, + keepdim=True, ) + # https://github.com/kornia/kornia/issues/2848 + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 6b2601c851f..afd71521002 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -9,7 +9,6 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn -from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection @@ -226,9 +225,6 @@ def training_step( Returns: The loss tensor. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') - x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -249,8 +245,6 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0] @@ -295,8 +289,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') x = batch['image'] y = batch['mask'] batch_size = x.shape[0]