Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sentinel2_CDL Datamodule #1889

Merged
merged 47 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
3107838
cdlsentinel2
yichiac Feb 17, 2024
0b77b32
update kwargs
yichiac Feb 17, 2024
2e0c93e
style
yichiac Feb 17, 2024
25c8d64
arg type
yichiac Feb 18, 2024
c41ebf1
add cov
yichiac Feb 18, 2024
584fcca
kwargs
yichiac Feb 18, 2024
9514bf3
update cdl data.py for intersection
yichiac Feb 18, 2024
4a60675
style
yichiac Feb 18, 2024
5512164
create 2022 cdl for intersection
yichiac Feb 19, 2024
5ff55cd
test roi method
yichiac Feb 19, 2024
8716e40
style
yichiac Feb 19, 2024
45eabd4
test_cdl year update
yichiac Feb 19, 2024
b333ac4
intersection
yichiac Feb 19, 2024
7472db4
random_grid_cell_assignment
yichiac Feb 21, 2024
1e8e884
Merge branch 'main' into datamodules/cdlsentinel2
yichiac Feb 21, 2024
ad235a5
Merge branch 'microsoft:main' into datamodules/cdlsentinel2
yichiac Mar 3, 2024
5856fcf
add comments and line
yichiac Mar 3, 2024
615fd4d
add description
yichiac Mar 3, 2024
10fe0ac
Merge branch 'microsoft:main' into datamodules/cdlsentinel2
yichiac Mar 3, 2024
73141a6
add doc
yichiac Mar 4, 2024
aa67fba
Merge branch 'microsoft:main' into datamodules/cdlsentinel2
yichiac Mar 7, 2024
cb47ba1
Merge branch 'microsoft:main' into datamodules/cdlsentinel2
yichiac Mar 8, 2024
fae737c
Merge branch 'main' into datamodules/cdlsentinel2
yichiac Mar 12, 2024
86e968d
Update SIZE variable in sentinel2/data.py and test stage in datamodul…
yichiac Mar 12, 2024
8e1bedc
Merge branch 'main' into datamodules/cdlsentinel2
yichiac Mar 12, 2024
424fbe2
merge val_aug and test_aug to aug
yichiac Mar 13, 2024
e67ef80
rename cdlsentinel2 to sentinel2cdl
yichiac Mar 14, 2024
9c02a9e
Merge branch 'main' into datamodules/cdlsentinel2
yichiac Mar 14, 2024
e71a9a6
fix isort
yichiac Mar 14, 2024
4d8c742
No need to monkeypatch CDL
adamjstewart Mar 15, 2024
01010a2
Smaller backbone == faster tests
adamjstewart Mar 15, 2024
3bab1a3
Sort docs alphabetically
adamjstewart Mar 15, 2024
44fb98f
Smaller Sentinel-2 test files
adamjstewart Mar 15, 2024
4b0cdf3
Smaller CDL files, don't delete directory
adamjstewart Mar 15, 2024
ff42af3
Sort imports alphabetically
adamjstewart Mar 15, 2024
5b3a69f
Fix doc names
adamjstewart Mar 15, 2024
859f24e
extra_args not needed
adamjstewart Mar 15, 2024
e0a9893
center crop doesn't do anything
adamjstewart Mar 15, 2024
4c05652
blacken
adamjstewart Mar 15, 2024
9acf859
Revert "extra_args not needed"
adamjstewart Mar 15, 2024
ec1ae23
Add underscore to filename
adamjstewart Mar 15, 2024
dd9deca
Add plot method
adamjstewart Mar 15, 2024
705b3ca
import Figure
yichiac Mar 15, 2024
7df5ef8
Merge branch 'main' into datamodules/cdlsentinel2
yichiac Mar 15, 2024
129ac6c
Merge branch 'microsoft:main' into datamodules/cdlsentinel2
yichiac Mar 19, 2024
41be84a
split 80-10-10
yichiac Mar 19, 2024
c35a3b2
style
yichiac Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ NAIP

.. autoclass:: NAIPChesapeakeDataModule

Sentinel
^^^^^^^^

.. autoclass:: Sentinel2CDLDataModule

Non-geospatial DataModules
--------------------------

Expand Down
18 changes: 18 additions & 0 deletions tests/conf/sentinel2cdl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 13
num_classes: 134
num_filters: 1
ignore_index: 0
data:
class_path: Sentinel2CDLDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
cdl_paths: "tests/data/cdl"
sentinel2_paths: "tests/data/sentinel2"
Binary file removed tests/data/cdl/2020_30m_cdls.zip
Binary file not shown.
Binary file removed tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif
Binary file not shown.
Binary file removed tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif.ovr
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls.zip
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif.ovr
Binary file not shown.
Binary file added tests/data/cdl/2022_30m_cdls.zip
Binary file not shown.
Binary file added tests/data/cdl/2022_30m_cdls/2022_30m_cdls.tif
Binary file not shown.
Binary file not shown.
Binary file added tests/data/cdl/2023_30m_cdls.zip
Binary file not shown.
Binary file added tests/data/cdl/2023_30m_cdls/2023_30m_cdls.tif
Binary file not shown.
Binary file not shown.
11 changes: 5 additions & 6 deletions tests/data/cdl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

import numpy as np
import rasterio
from rasterio import Affine

SIZE = 32
SIZE = 128

np.random.seed(0)
random.seed(0)
Expand All @@ -22,8 +23,8 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
profile["driver"] = "GTiff"
profile["dtype"] = dtype
profile["count"] = num_channels
profile["crs"] = "epsg:4326"
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
profile["crs"] = "epsg:32616"
profile["transform"] = Affine(30, 0.0, 399960.0, 0.0, -30, 4500000.0)
profile["height"] = SIZE
profile["width"] = SIZE
profile["compress"] = "lzw"
Expand All @@ -49,7 +50,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
src.write_colormap(1, cmap)


directories = ["2020_30m_cdls", "2021_30m_cdls"]
directories = ["2023_30m_cdls", "2022_30m_cdls"]
raster_extensions = [".tif", ".tif.ovr"]


Expand Down Expand Up @@ -77,5 +78,3 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
with open(filename, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{filename}: {md5}")

shutil.rmtree(dir)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/sentinel2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 36
SIZE = 128

np.random.seed(0)

Expand Down
14 changes: 7 additions & 7 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
monkeypatch.setattr(torchgeo.datasets.cdl, "download_url", download_url)

md5s = {
2021: "e929beb9c8e59fa1d7b7f82e64edaae1",
2020: "e95c2d40ce0c261ed6ee0bd00b49e4b6",
2023: "3fbd3eecf92b8ce1ae35060ada463c6d",
2022: "826c6fd639d9cdd94a44302fbc5b76c3",
}
monkeypatch.setattr(CDL, "md5s", md5s)
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
Expand All @@ -48,7 +48,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
transforms=transforms,
download=True,
checksum=True,
years=[2020, 2021],
years=[2023, 2022],
)

def test_getitem(self, dataset: CDL) -> None:
Expand All @@ -60,7 +60,7 @@ def test_getitem(self, dataset: CDL) -> None:
def test_classes(self) -> None:
root = os.path.join("tests", "data", "cdl")
classes = list(CDL.cmap.keys())[:5]
ds = CDL(root, years=[2021], classes=classes)
ds = CDL(root, years=[2023], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)
Expand All @@ -75,19 +75,19 @@ def test_or(self, dataset: CDL) -> None:

def test_full_year(self, dataset: CDL) -> None:
bbox = dataset.bounds
time = datetime(2021, 6, 1).timestamp()
time = datetime(2023, 6, 1).timestamp()
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
next(dataset.index.intersection(tuple(query)))

def test_already_extracted(self, dataset: CDL) -> None:
CDL(dataset.paths, years=[2020, 2021])
CDL(dataset.paths, years=[2023, 2022])

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
root = str(tmp_path)
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
CDL(root, years=[2020, 2021])
CDL(root, years=[2023, 2022])

def test_invalid_year(self, tmp_path: Path) -> None:
with pytest.raises(
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s1",
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"sentinel2cdl",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
from .sentinel2cdl import Sentinel2CDLDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
Expand All @@ -43,6 +44,7 @@

__all__ = (
# GeoDataset
"Sentinel2CDLDataModule",
"ChesapeakeCVPRDataModule",
"L7IrishDataModule",
"L8BiomeDataModule",
Expand Down
110 changes: 110 additions & 0 deletions torchgeo/datamodules/sentinel2cdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CDLSentinel2 datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample

from ..datasets import CDL, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2CDLDataModule(GeoDataModule):
"""LightningDataModule implementation for the CDL dataset.

.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 16,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new Sentinel2CDLDataModule instance.

Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.CDL` (prefix keys with ``cdl_``) and
:class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
# Define prefix for Cropland Data Layer (CDL) and Sentinel-2 arguments
cdl_signature = "cdl_"
sentinel2_signature = "sentinel2_"
self.cdl_kwargs = {}
self.sentinel2_kwargs = {}

for key, val in kwargs.items():
# Check if the current key starts with the CDL prefix
if key.startswith(cdl_signature):
# If so, extract the key-value pair to the CDL dictionary
self.cdl_kwargs[key[len(cdl_signature) :]] = val
# Check if the current key starts with the Sentinel-2 prefix
elif key.startswith(sentinel2_signature):
# If so, extract the key-value pair to the Sentinel-2 dictionary
self.sentinel2_kwargs[key[len(sentinel2_signature) :]] = val

super().__init__(
CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
yichiac marked this conversation as resolved.
Show resolved Hide resolved
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.CenterCrop(self.patch_size),
data_keys=["image", "mask"],
)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.

Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.cdl = CDL(**self.cdl_kwargs)
self.dataset = self.sentinel2 & self.cdl

generator = torch.Generator().manual_seed(0)

(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.7, 0.10, 0.20], grid_size=8, generator=generator
yichiac marked this conversation as resolved.
Show resolved Hide resolved
)
)
if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)
Loading