diff --git a/torchgeo/datamodules/nccm.py b/torchgeo/datamodules/nccm.py index bd39c6bae5e..310b8f2aeb2 100644 --- a/torchgeo/datamodules/nccm.py +++ b/torchgeo/datamodules/nccm.py @@ -21,6 +21,7 @@ class NCCMSentinel2DataModule(GeoDataModule): """LightningDataModule implementation for the NCCM and Sentinel2 datasets. Uses the train/val/test splits from the dataset. + .. versionadded:: 0.6 """ @@ -67,12 +68,7 @@ def __init__( }, ) - self.val_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - K.CenterCrop(self.patch_size), - data_keys=["image", "mask"], - ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), data_keys=["image", "mask"], diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 8703999988b..ef9bb05f10f 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -122,7 +122,6 @@ def __init__( self.length += 1 self.hits.append(hit) areas.append(bounds.area) - if length is not None: self.length = length diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 8a6ca077735..0be8a2e195d 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -254,7 +254,6 @@ def validation_step( batch["prediction"] = y_hat.argmax(dim=1) for key in ["image", "mask", "prediction"]: batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] fig: Optional[Figure] = None