diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 87fd8b48bf1..f41463b9b76 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -35,6 +35,7 @@ Sentinel ^^^^^^^^ .. autoclass:: Sentinel2CDLDataModule +.. autoclass:: Sentinel2EuroCropsDataModule .. autoclass:: Sentinel2NCCMDataModule .. autoclass:: Sentinel2SouthAmericaSoybeanDataModule diff --git a/tests/conf/sentinel2_eurocrops.yaml b/tests/conf/sentinel2_eurocrops.yaml new file mode 100644 index 00000000000..b3633d3590d --- /dev/null +++ b/tests/conf/sentinel2_eurocrops.yaml @@ -0,0 +1,17 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: "ce" + model: "unet" + backbone: "resnet18" + in_channels: 13 + num_classes: 3 + num_filters: 1 +data: + class_path: Sentinel2EuroCropsDataModule + init_args: + batch_size: 2 + patch_size: 16 + dict_kwargs: + sentinel2_paths: "tests/data/sentinel2" + eurocrops_paths: "tests/data/eurocrops" diff --git a/tests/data/eurocrops/AA.zip b/tests/data/eurocrops/AA.zip index 8ac0988a4da..635f499dbf9 100644 Binary files a/tests/data/eurocrops/AA.zip and b/tests/data/eurocrops/AA.zip differ diff --git a/tests/data/eurocrops/AA_2022_EC21.cpg b/tests/data/eurocrops/AA_2022_EC21.cpg new file mode 100644 index 00000000000..cd89cb9758e --- /dev/null +++ b/tests/data/eurocrops/AA_2022_EC21.cpg @@ -0,0 +1 @@ +ISO-8859-1 \ No newline at end of file diff --git a/tests/data/eurocrops/AA_2022_EC21.dbf b/tests/data/eurocrops/AA_2022_EC21.dbf new file mode 100644 index 00000000000..dc5993db2d9 Binary files /dev/null and b/tests/data/eurocrops/AA_2022_EC21.dbf differ diff --git a/tests/data/eurocrops/AA_2022_EC21.prj b/tests/data/eurocrops/AA_2022_EC21.prj new file mode 100644 index 00000000000..c6996cc55e2 --- /dev/null +++ b/tests/data/eurocrops/AA_2022_EC21.prj @@ -0,0 +1 @@ +PROJCS["WGS_1984_UTM_Zone_16N",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-87.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]] \ No newline at end of file diff --git a/tests/data/eurocrops/AA_2022_EC21.shp b/tests/data/eurocrops/AA_2022_EC21.shp new file mode 100644 index 00000000000..fa5235ec666 Binary files /dev/null and b/tests/data/eurocrops/AA_2022_EC21.shp differ diff --git a/tests/data/eurocrops/AA_2022_EC21.shx b/tests/data/eurocrops/AA_2022_EC21.shx new file mode 100644 index 00000000000..d958674b5de Binary files /dev/null and b/tests/data/eurocrops/AA_2022_EC21.shx differ diff --git a/tests/data/eurocrops/data.py b/tests/data/eurocrops/data.py index b8ea55e6df2..54a5867e01b 100755 --- a/tests/data/eurocrops/data.py +++ b/tests/data/eurocrops/data.py @@ -12,15 +12,22 @@ from rasterio.crs import CRS from shapely.geometry import Polygon, mapping -SIZE = 100 +# Size of example crop field polygon in projection units. +# This is set to align with Sentinel-2 test data, which is a 128x128 image at 10 +# projection units per pixel (1280x1280 projection units). +SIZE = 1280 def create_data_file(dataname): schema = {"geometry": "Polygon", "properties": {"EC_hcat_c": "str"}} with fiona.open( - dataname, "w", crs=CRS.from_epsg(31287), driver="ESRI Shapefile", schema=schema + dataname, "w", crs=CRS.from_epsg(32616), driver="ESRI Shapefile", schema=schema ) as shpfile: coordinates = [[0.0, 0.0], [0.0, SIZE], [SIZE, SIZE], [SIZE, 0.0], [0.0, 0.0]] + # The offset aligns with tests/data/sentinel2/data.py. + offset = [399960, 4500000 - SIZE] + coordinates = [[x + offset[0], y + offset[1]] for x, y in coordinates] + polygon = Polygon(coordinates) properties = {"EC_hcat_c": "1000000010"} shpfile.write({"geometry": mapping(polygon), "properties": properties}) @@ -36,12 +43,12 @@ def create_csv(fname): if __name__ == "__main__": csvname = "HCAT2.csv" - dataname = "AA_2021_EC21.shp" + dataname = "AA_2022_EC21.shp" supportnames = [ - "AA_2021_EC21.cpg", - "AA_2021_EC21.dbf", - "AA_2021_EC21.prj", - "AA_2021_EC21.shx", + "AA_2022_EC21.cpg", + "AA_2022_EC21.dbf", + "AA_2022_EC21.prj", + "AA_2022_EC21.shx", ] zipfilename = "AA.zip" diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index e13da1f4697..0477b8617ac 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -36,7 +36,7 @@ def dataset( monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) monkeypatch.setattr(torchgeo.datasets.eurocrops, "download_url", download_url) monkeypatch.setattr( - EuroCrops, "zenodo_files", [("AA.zip", "9a45308fc32b535318562c1394fd2911")] + EuroCrops, "zenodo_files", [("AA.zip", "b2ef5cac231294731c1dfea47cba544d")] ) monkeypatch.setattr(EuroCrops, "hcat_md5", "22d61cf3b316c8babfd209ae81419d8f") base_url = os.path.join("tests", "data", "eurocrops") + os.sep diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 40da4efead6..42fe7693e0b 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -74,6 +74,7 @@ class TestSemanticSegmentationTask: "sen12ms_s2_all", "sen12ms_s2_reduced", "sentinel2_cdl", + "sentinel2_eurocrops", "sentinel2_nccm", "sentinel2_south_america_soybean", "spacenet1", diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index a2d3eb1a666..5ee3a47aaaa 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -30,6 +30,7 @@ from .seco import SeasonalContrastS2DataModule from .sen12ms import SEN12MSDataModule from .sentinel2_cdl import Sentinel2CDLDataModule +from .sentinel2_eurocrops import Sentinel2EuroCropsDataModule from .sentinel2_nccm import Sentinel2NCCMDataModule from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule from .skippd import SKIPPDDataModule @@ -53,6 +54,7 @@ "L8BiomeDataModule", "NAIPChesapeakeDataModule", "Sentinel2CDLDataModule", + "Sentinel2EuroCropsDataModule", "Sentinel2NCCMDataModule", "Sentinel2SouthAmericaSoybeanDataModule", # NonGeoDataset diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py new file mode 100644 index 00000000000..076131f8a89 --- /dev/null +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""EuroCrops datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from kornia.constants import DataKey, Resample +from matplotlib.figure import Figure + +from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from .geo import GeoDataModule + + +class Sentinel2EuroCropsDataModule(GeoDataModule): + """LightningDataModule implementation for the EuroCrops and Sentinel2 datasets. + + Uses the train/val/test splits from the dataset. + + .. versionadded:: 0.6 + """ + + def __init__( + self, + batch_size: int = 64, + patch_size: int | tuple[int, int] = 256, + length: int | None = None, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new Sentinel2EuroCropsDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.EuroCrops` (prefix keys with ``eurocrops_``) + and :class:`~torchgeo.datasets.Sentinel2` + (prefix keys with ``sentinel2_``). + """ + eurocrops_signature = "eurocrops_" + sentinel2_signature = "sentinel2_" + self.eurocrops_kwargs = {} + self.sentinel2_kwargs = {} + for key, val in kwargs.items(): + if key.startswith(eurocrops_signature): + self.eurocrops_kwargs[key[len(eurocrops_signature) :]] = val + elif key.startswith(sentinel2_signature): + self.sentinel2_kwargs[key[len(sentinel2_signature) :]] = val + + super().__init__( + EuroCrops, + batch_size, + patch_size, + length, + num_workers, + **self.eurocrops_kwargs, + ) + + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), + K.RandomVerticalFlip(p=0.5), + K.RandomHorizontalFlip(p=0.5), + data_keys=["image", "mask"], + extra_args={ + DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + }, + ) + + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + ) + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.sentinel2 = Sentinel2(**self.sentinel2_kwargs) + self.eurocrops = EuroCrops(**self.eurocrops_kwargs) + self.dataset = self.sentinel2 & self.eurocrops + + generator = torch.Generator().manual_seed(0) + (self.train_dataset, self.val_dataset, self.test_dataset) = ( + random_grid_cell_assignment( + self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator + ) + ) + if stage in ["fit"]: + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ["fit", "validate"]: + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ["test"]: + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) + + def plot(self, *args: Any, **kwargs: Any) -> Figure: + """Run EuroCrops plot method. + + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A matplotlib Figure with the image, ground truth, and predictions. + """ + return self.eurocrops.plot(*args, **kwargs) diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index b1cfa57b211..140e6f310c6 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -138,6 +138,10 @@ def _check_integrity(self) -> bool: Returns: True if dataset files are found and/or MD5s match, else False """ + # Check if the extracted files already exist + if self.files and not self.checksum: + return True + assert isinstance(self.paths, str) filepath = os.path.join(self.paths, self.hcat_fname) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f175df463cb..590f6f7dcf0 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -608,6 +608,19 @@ class VectorDataset(GeoDataset): #: Not used if :attr:`filename_regex` does not contain a ``date`` group. date_format = "%Y%m%d" + @property + def dtype(self) -> torch.dtype: + """The dtype of the dataset (overrides the dtype of the data file via a cast). + + Defaults to long. + + Returns: + the dtype of the dataset + + .. versionadded:: 0.6 + """ + return torch.long + def __init__( self, paths: str | Iterable[str] = "data", @@ -734,6 +747,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: # Use array_to_tensor since rasterize may return uint16/uint32 arrays. masks = array_to_tensor(masks) + masks = masks.to(self.dtype) sample = {"mask": masks, "crs": self.crs, "bbox": query} if self.transforms is not None: