Skip to content

Commit

Permalink
Add AgriFieldNet India Challenge dataset (#1459)
Browse files Browse the repository at this point in the history
* add agrifieldnet dataset

* modified len check

* improve _download

* remove augmentation and wrong datamodule names

* update data.py and dataset

* update splits

* remove patch_size change

* fix style issues

* add yaml and modify/test for training

* fix data path and add trainer

* export prediction

* fix integrity check and len

* extract predction

* adding create submission file function

* adding create submission file function

* hyperparam tuning exp

* backup experiments

* remove redundant files

* reverse segmentation.py

* resolve minor issues

* modify yaml and add exp files

* update data.py

* remove outdated train.py

* update dataset, test, and new data

* fix style

* fix doc api

* remove datamodule

* fix geo_datasets.csv

* fix codecov

* fix read tif issue

* Update torchgeo/datasets/agrifieldnet.py

Co-authored-by: Adam J. Stewart <[email protected]>

* fix init

* add ordinal_cmap to pred and remove comments

* remove suffix

* remove download entirely

* style

* Update agrifieldnet.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update agrifieldnet.py

Co-authored-by: Adam J. Stewart <[email protected]>

* remove url and if statement

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
yichiac and adamjstewart authored Feb 12, 2024
1 parent ff9555a commit 8af188c
Show file tree
Hide file tree
Showing 78 changed files with 465 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ Aboveground Woody Biomass

.. autoclass:: AbovegroundLiveWoodyBiomassDensity

AgriFieldNet
^^^^^^^^^^^^

.. autoclass:: AgriFieldNet

Airphen
^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Dataset,Type,Source,License,Size (px),Resolution (m)
`Aboveground Woody Biomass`_,Masks,"Landsat, LiDAR","CC-BY-4.0","40,000x40,000",30
`AgriFieldNet`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10
`Airphen`_,Imagery,Airphen,-,"1,280x960",0.047--0.09
`Aster Global DEM`_,Masks,Aster,"public domain","3,601x3,601",30
`Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,-
Expand Down
108 changes: 108 additions & 0 deletions tests/data/agrifieldnet/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_test_data(paths: str) -> str:
"""Create test data archive for AgriFieldNet dataset.
Args:
paths: path to store test data
n_samples: number of samples.
Returns:
md5 hash of created archive
"""
dtype = np.uint8
dtype_max = np.iinfo(dtype).max

SIZE = 32

np.random.seed(0)

bands = (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B11",
"B12",
)

profile = {
"dtype": dtype,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": CRS.from_epsg(32644),
"transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
}

source_dir = os.path.join(paths, "source")
train_mask_dir = os.path.join(paths, "train_labels")
test_field_dir = os.path.join(paths, "test_labels")

os.makedirs(source_dir, exist_ok=True)
os.makedirs(train_mask_dir, exist_ok=True)
os.makedirs(test_field_dir, exist_ok=True)

source_unique_folder_ids = ["32407", "8641e", "a419f", "eac11", "ff450"]
train_folder_ids = source_unique_folder_ids[0:5]
test_folder_ids = source_unique_folder_ids[3:5]

for id in source_unique_folder_ids:
directory = os.path.join(
source_dir, "ref_agrifieldnet_competition_v1_source_" + id
)
os.makedirs(directory, exist_ok=True)

for band in bands:
train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
path = os.path.join(
directory, f"ref_agrifieldnet_competition_v1_source_{id}_{band}_10m.tif"
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_arr, 1)

for id in train_folder_ids:
train_mask_arr = np.random.randint(size=(SIZE, SIZE), low=0, high=6)
path = os.path.join(
train_mask_dir, f"ref_agrifieldnet_competition_v1_labels_train_{id}.tif"
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_mask_arr, 1)

train_field_arr = np.random.randint(20, size=(SIZE, SIZE), dtype=np.uint16)
path = os.path.join(
train_mask_dir,
f"ref_agrifieldnet_competition_v1_labels_train_{id}_field_ids.tif",
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_field_arr, 1)

for id in test_folder_ids:
test_field_arr = np.random.randint(10, 30, size=(SIZE, SIZE), dtype=np.uint16)
path = os.path.join(
test_field_dir,
f"ref_agrifieldnet_competition_v1_labels_test_{id}_field_ids.tif",
)
with rasterio.open(path, "w", **profile) as src:
src.write(test_field_arr, 1)


if __name__ == "__main__":
generate_test_data(os.getcwd())
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/datasets/test_agrifieldnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS

from torchgeo.datasets import (
AgriFieldNet,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
RGBBandsMissingError,
UnionDataset,
)


class TestAgriFieldNet:
@pytest.fixture
def dataset(self) -> AgriFieldNet:
path = os.path.join("tests", "data", "agrifieldnet")
transforms = nn.Identity()
return AgriFieldNet(paths=path, transforms=transforms)

def test_getitem(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_and(self, dataset: AgriFieldNet) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: AgriFieldNet) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: AgriFieldNet) -> None:
AgriFieldNet(paths=dataset.paths)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
AgriFieldNet(str(tmp_path))

def test_plot(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: AgriFieldNet) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]

def test_rgb_bands_absent_plot(self, dataset: AgriFieldNet) -> None:
with pytest.raises(
RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
):
ds = AgriFieldNet(dataset.paths, bands=["B01", "B02", "B05"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .advance import ADVANCE
from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
from .agrifieldnet import AgriFieldNet
from .airphen import Airphen
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
Expand Down Expand Up @@ -139,6 +140,7 @@
__all__ = (
# GeoDataset
"AbovegroundLiveWoodyBiomassDensity",
"AgriFieldNet",
"Airphen",
"AsterGDEM",
"CanadianBuildingFootprints",
Expand Down
Loading

0 comments on commit 8af188c

Please sign in to comment.