Skip to content

Commit

Permalink
Use random_grid_cell_assignment for splits
Browse files Browse the repository at this point in the history
  • Loading branch information
cookie-kyu committed Feb 22, 2024
1 parent 602c4cc commit c9acafc
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions torchgeo/datamodules/south_america_soybean.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from matplotlib.figure import Figure

from ..datasets import BoundingBox, Sentinel2, SouthAmericaSoybean
from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule
Expand Down Expand Up @@ -74,28 +75,22 @@ def setup(self, stage: str) -> None:
)
self.dataset = self.sentinel2 & self.south_america_soybean

roi = self.dataset.bounds
midx = roi.minx + (roi.maxx - roi.minx) / 2
midy = roi.miny + (roi.maxy - roi.miny) / 2
generator = torch.Generator().manual_seed(1)
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(self.dataset, [0.6, 0.2, 0.2], 2, generator)
)

if stage in ["fit"]:
train_roi = BoundingBox(
roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt
)
self.train_batch_sampler = RandomBatchGeoSampler(
self.dataset, self.patch_size, self.batch_size, self.length, train_roi
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
self.val_sampler = GridGeoSampler(
self.dataset, self.patch_size, self.patch_size, val_roi
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
test_roi = BoundingBox(
roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt
)
self.test_sampler = GridGeoSampler(
self.dataset, self.patch_size, self.patch_size, test_roi
self.test_dataset, self.patch_size, self.patch_size
)

def plot(self, *args: Any, **kwargs: Any) -> Figure:
Expand Down

0 comments on commit c9acafc

Please sign in to comment.