Skip to content

Commit

Permalink
Add SouthAmericaSoybeanSentinel2DataModule
Browse files Browse the repository at this point in the history
  • Loading branch information
cookie-kyu committed Feb 12, 2024
1 parent 477ddd6 commit 419cb64
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .sen12ms import SEN12MSDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .south_america_soybean import SouthAmericaSoybean
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule
from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule
Expand Down Expand Up @@ -71,6 +72,7 @@
"SEN12MSDataModule",
"SKIPPDDataModule",
"So2SatDataModule",
"SouthAmericaSoybeanSentinel2DataModule",
"SpaceNet1DataModule",
"SSL4EOLBenchmarkDataModule",
"SSL4EOLDataModule",
Expand Down
109 changes: 109 additions & 0 deletions torchgeo/datamodules/south_america_soybean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""South America Soybean datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
from matplotlib.figure import Figure

from ..datasets import SouthAmericaSoybean, BoundingBox, Sentinel2
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class SouthAmericaSoybeanSentinel2DataModule(GeoDataModule):
"""LightningDataModule implementation for SouthAmericaSoybean and Sentinel2 datasets.
Uses the train/val/test splits from the dataset.
"""

def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 16,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new SouthAmericaSoybeanDataModule instance.
Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SouthAmericaSoybean` (prefix keys with ``south_america_soybean_``) and
:class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
self.southamericasoybean_kwargs = {}
self.sentinel2_kwargs = {}
for key, val in kwargs.items():
if key.startswith("south_america_soybean_"):
self.southamericasoybean_kwargs[key[22:]] = val
elif key.startswith("sentinel2_"):
self.sentinel2_kwargs[key[10:]] = val

super().__init__(
SouthAmericaSoybean,
batch_size,
patch_size,
length,
num_workers,
**self.south_america_soybean_kwargs,
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.south_america_soybean = SouthAmericaSoybean(**self.eurocrops_kwargs)
self.dataset = self.sentinel2 & self.south_america_soybean

roi = self.dataset.bounds
midx = roi.minx + (roi.maxx - roi.minx) / 2
midy = roi.miny + (roi.maxy - roi.miny) / 2

if stage in ["fit"]:
train_roi = BoundingBox(
roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt
)
self.train_batch_sampler = RandomBatchGeoSampler(
self.dataset, self.patch_size, self.batch_size, self.length, train_roi
)
if stage in ["fit", "validate"]:
val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
self.val_sampler = GridGeoSampler(
self.dataset, self.patch_size, self.patch_size, val_roi
)
if stage in ["test"]:
test_roi = BoundingBox(
roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt
)
self.test_sampler = GridGeoSampler(
self.dataset, self.patch_size, self.patch_size, test_roi
)

def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run SouthAmericaSoybean plot method.
Args:
*args: Arguments passed to plot method.
**kwargs: Keyword arguments passed to plot method.
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
.. versionadded:: 0.4
"""

return self.south_america_soybean.plot(*args, **kwargs)

0 comments on commit 419cb64

Please sign in to comment.