Skip to content

Commit

Permalink
Add eurocrops data module. (#1869)
Browse files Browse the repository at this point in the history
* Add eurocrops data module.

It is based on NAIPChesapeakeDataModule which splits bounding box of dataset
into 1/2 train, 1/4 val, and 1/4 test. This may not be the best way to train
an actual model.

* misc fixes

* various fixes per discussion

* update eurocrops test data

* fix

* fix version added placement

* fix failing test by forcing integrity test when checksum is requested

* Clarify SIZE setting in eurocrops test data

* fix currently remaining issues with eurocrops data module

* fix style

* more style fix

* Add documentation

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
favyen2 and adamjstewart authored Apr 12, 2024
1 parent abceea0 commit 83353b0
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Sentinel
^^^^^^^^

.. autoclass:: Sentinel2CDLDataModule
.. autoclass:: Sentinel2EuroCropsDataModule
.. autoclass:: Sentinel2NCCMDataModule
.. autoclass:: Sentinel2SouthAmericaSoybeanDataModule

Expand Down
17 changes: 17 additions & 0 deletions tests/conf/sentinel2_eurocrops.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 13
num_classes: 3
num_filters: 1
data:
class_path: Sentinel2EuroCropsDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
sentinel2_paths: "tests/data/sentinel2"
eurocrops_paths: "tests/data/eurocrops"
Binary file modified tests/data/eurocrops/AA.zip
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/eurocrops/AA_2022_EC21.cpg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ISO-8859-1
Binary file added tests/data/eurocrops/AA_2022_EC21.dbf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/eurocrops/AA_2022_EC21.prj
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PROJCS["WGS_1984_UTM_Zone_16N",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-87.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]
Binary file added tests/data/eurocrops/AA_2022_EC21.shp
Binary file not shown.
Binary file added tests/data/eurocrops/AA_2022_EC21.shx
Binary file not shown.
21 changes: 14 additions & 7 deletions tests/data/eurocrops/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@
from rasterio.crs import CRS
from shapely.geometry import Polygon, mapping

SIZE = 100
# Size of example crop field polygon in projection units.
# This is set to align with Sentinel-2 test data, which is a 128x128 image at 10
# projection units per pixel (1280x1280 projection units).
SIZE = 1280


def create_data_file(dataname):
schema = {"geometry": "Polygon", "properties": {"EC_hcat_c": "str"}}
with fiona.open(
dataname, "w", crs=CRS.from_epsg(31287), driver="ESRI Shapefile", schema=schema
dataname, "w", crs=CRS.from_epsg(32616), driver="ESRI Shapefile", schema=schema
) as shpfile:
coordinates = [[0.0, 0.0], [0.0, SIZE], [SIZE, SIZE], [SIZE, 0.0], [0.0, 0.0]]
# The offset aligns with tests/data/sentinel2/data.py.
offset = [399960, 4500000 - SIZE]
coordinates = [[x + offset[0], y + offset[1]] for x, y in coordinates]

polygon = Polygon(coordinates)
properties = {"EC_hcat_c": "1000000010"}
shpfile.write({"geometry": mapping(polygon), "properties": properties})
Expand All @@ -36,12 +43,12 @@ def create_csv(fname):

if __name__ == "__main__":
csvname = "HCAT2.csv"
dataname = "AA_2021_EC21.shp"
dataname = "AA_2022_EC21.shp"
supportnames = [
"AA_2021_EC21.cpg",
"AA_2021_EC21.dbf",
"AA_2021_EC21.prj",
"AA_2021_EC21.shx",
"AA_2022_EC21.cpg",
"AA_2022_EC21.dbf",
"AA_2022_EC21.prj",
"AA_2022_EC21.shx",
]
zipfilename = "AA.zip"

Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def dataset(
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
monkeypatch.setattr(torchgeo.datasets.eurocrops, "download_url", download_url)
monkeypatch.setattr(
EuroCrops, "zenodo_files", [("AA.zip", "9a45308fc32b535318562c1394fd2911")]
EuroCrops, "zenodo_files", [("AA.zip", "b2ef5cac231294731c1dfea47cba544d")]
)
monkeypatch.setattr(EuroCrops, "hcat_md5", "22d61cf3b316c8babfd209ae81419d8f")
base_url = os.path.join("tests", "data", "eurocrops") + os.sep
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"sentinel2_cdl",
"sentinel2_eurocrops",
"sentinel2_nccm",
"sentinel2_south_america_soybean",
"spacenet1",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .sentinel2_eurocrops import Sentinel2EuroCropsDataModule
from .sentinel2_nccm import Sentinel2NCCMDataModule
from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule
from .skippd import SKIPPDDataModule
Expand All @@ -53,6 +54,7 @@
"L8BiomeDataModule",
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
"Sentinel2EuroCropsDataModule",
"Sentinel2NCCMDataModule",
"Sentinel2SouthAmericaSoybeanDataModule",
# NonGeoDataset
Expand Down
121 changes: 121 additions & 0 deletions torchgeo/datamodules/sentinel2_eurocrops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""EuroCrops datamodule."""

from typing import Any

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2EuroCropsDataModule(GeoDataModule):
"""LightningDataModule implementation for the EuroCrops and Sentinel2 datasets.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 64,
patch_size: int | tuple[int, int] = 256,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new Sentinel2EuroCropsDataModule instance.
Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.EuroCrops` (prefix keys with ``eurocrops_``)
and :class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
eurocrops_signature = "eurocrops_"
sentinel2_signature = "sentinel2_"
self.eurocrops_kwargs = {}
self.sentinel2_kwargs = {}
for key, val in kwargs.items():
if key.startswith(eurocrops_signature):
self.eurocrops_kwargs[key[len(eurocrops_signature) :]] = val
elif key.startswith(sentinel2_signature):
self.sentinel2_kwargs[key[len(sentinel2_signature) :]] = val

super().__init__(
EuroCrops,
batch_size,
patch_size,
length,
num_workers,
**self.eurocrops_kwargs,
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.eurocrops = EuroCrops(**self.eurocrops_kwargs)
self.dataset = self.sentinel2 & self.eurocrops

generator = torch.Generator().manual_seed(0)
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
)
)
if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)

def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run EuroCrops plot method.
Args:
*args: Arguments passed to plot method.
**kwargs: Keyword arguments passed to plot method.
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
return self.eurocrops.plot(*args, **kwargs)
4 changes: 4 additions & 0 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
# Check if the extracted files already exist
if self.files and not self.checksum:
return True

assert isinstance(self.paths, str)

filepath = os.path.join(self.paths, self.hcat_fname)
Expand Down
14 changes: 14 additions & 0 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,19 @@ class VectorDataset(GeoDataset):
#: Not used if :attr:`filename_regex` does not contain a ``date`` group.
date_format = "%Y%m%d"

@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).
Defaults to long.
Returns:
the dtype of the dataset
.. versionadded:: 0.6
"""
return torch.long

def __init__(
self,
paths: str | Iterable[str] = "data",
Expand Down Expand Up @@ -734,6 +747,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
# Use array_to_tensor since rasterize may return uint16/uint32 arrays.
masks = array_to_tensor(masks)

masks = masks.to(self.dtype)
sample = {"mask": masks, "crs": self.crs, "bbox": query}

if self.transforms is not None:
Expand Down

0 comments on commit 83353b0

Please sign in to comment.