diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index eb0571076dc..e9796f63b42 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -20,6 +20,10 @@ updates:
           - "torch"
           - "torchvision"
     ignore:
+      # lightning 2.3+ contains known bugs related to YAML parsing
+      # https://github.com/Lightning-AI/pytorch-lightning/issues/19977
+      - dependency-name: "lightning"
+        versions: ">=2.3"
       # setuptools releases new versions almost daily
       - dependency-name: "setuptools"
        update-types: ["version-update:semver-patch"]
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index cfffdb4b850..0254d7c6edc 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -9,7 +9,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Add label
         uses: actions/labeler@v5.0.0
         with:
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 68cdbc9c693..67ebd42538e 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -12,9 +12,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
@@ -40,9 +40,9 @@ jobs:
     runs-on: ubuntu-latest
    steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml
index 5133ce2bd49..860c820e389 100644
--- a/.github/workflows/style.yaml
+++ b/.github/workflows/style.yaml
@@ -14,9 +14,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
@@ -39,9 +39,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
@@ -66,9 +66,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up nodejs
-        uses: actions/setup-node@v4.0.2
+        uses: actions/setup-node@v4.0.3
         with:
           node-version: "20"
           cache: "npm"
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 7a7a0980c73..ee7c4af31c2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -20,9 +20,9 @@ jobs:
         python-version: ["3.10", "3.11", "3.12"]
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: ${{ matrix.python-version }}
       - name: Cache dependencies
@@ -57,7 +57,7 @@ jobs:
           pytest --cov=torchgeo --cov-report=xml --durations=10
           python3 -m torchgeo --help
       - name: Report coverage
-        uses: codecov/codecov-action@v4.4.1
+        uses: codecov/codecov-action@v4.5.0
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
   minimum:
@@ -67,9 +67,9 @@ jobs:
       MPLBACKEND: Agg
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.10"
       - name: Cache dependencies
@@ -96,7 +96,7 @@ jobs:
           pytest --cov=torchgeo --cov-report=xml --durations=10
           python3 -m torchgeo --help
       - name: Report coverage
-        uses: codecov/codecov-action@v4.4.1
+        uses: codecov/codecov-action@v4.5.0
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
   datasets:
@@ -106,9 +106,9 @@ jobs:
       MPLBACKEND: Agg
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
@@ -129,7 +129,7 @@ jobs:
           pytest --cov=torchgeo --cov-report=xml --durations=10
           python3 -m torchgeo --help
       - name: Report coverage
-        uses: codecov/codecov-action@v4.4.1
+        uses: codecov/codecov-action@v4.5.0
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
 concurrency:
diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml
index eeff9acbd03..dd5b83b9119 100644
--- a/.github/workflows/tutorials.yaml
+++ b/.github/workflows/tutorials.yaml
@@ -16,9 +16,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone repo
-        uses: actions/checkout@v4.1.6
+        uses: actions/checkout@v4.1.7
       - name: Set up python
-        uses: actions/setup-python@v5.1.0
+        uses: actions/setup-python@v5.1.1
         with:
           python-version: "3.12"
       - name: Cache dependencies
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 164803fb8bd..b9c3cb56eca 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
         - lightning>=2.0.9
         - matplotlib>=3.8.1
         - numpy>=1.22
-        - pillow>=10.3.0
+        - pillow>=10.4.0
         - pytest>=6.1.2
         - pyvista>=0.34.2
         - scikit-image>=0.22.0
diff --git a/README.md b/README.md
index 38883a212d7..1d8565e57a8 100644
--- a/README.md
+++ b/README.md
@@ -239,7 +239,7 @@ torchgeo fit --config config.yaml
 # Validate-only
 torchgeo validate --config config.yaml
 # Calculate and report test accuracy
-torchgeo test --config config.yaml ckpt_path=...
+torchgeo test --config config.yaml --ckpt_path=...
 ```
 
 It can also be imported and used in a Python script if you need to extend it to add new features:
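The README fix above reflects how the torchgeo CLI parses arguments: it is built on LightningCLI, which does not accept bare `key=value` tokens, so the checkpoint must be passed as the `--ckpt_path` option. A quick way to confirm the flag exists (a minimal sketch, assuming only that torchgeo is installed — the same `python3 -m torchgeo --help` invocation the CI above exercises):

```python
import subprocess
import sys

# The `test` subcommand's help text should advertise an optional --ckpt_path flag
# (rather than accepting a bare ckpt_path=... positional token).
out = subprocess.run(
    [sys.executable, '-m', 'torchgeo', 'test', '--help'],
    capture_output=True, text=True, check=True,
)
assert '--ckpt_path' in out.stdout
```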
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index 2dac9021daa..dfc0ea5eb10 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -27,7 +27,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB
 `OSCD`_,CD,Sentinel-2,"CC-BY-4.0",24,2,"40--1,180",60,MSI
 `PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI
-`PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB
+`PatternNet`_,C,Google Earth,"CC-BY-4.0","30,400",38,256x256,0.06--5,RGB
 `Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI
 `QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR
 `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
diff --git a/docs/conf.py b/docs/conf.py
index 36af7d3d275..ce97ea5a685 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -144,22 +144,6 @@
     :class: colabbadge
     :alt: Open in Colab
     :target: {{ host }}/github/{{ repo }}/blob/{{ branch }}/{{ urlpath }}
-
-{% set host = "https://pccompute.westeurope.cloudapp.azure.com" %}
-{% set host = host ~ "/compute/hub/user-redirect/git-pull" %}
-{% set repo = "https%3A%2F%2Fgithub.com%2Fmicrosoft%2Ftorchgeo" %}
-{% set urlpath = "tree%2Ftorchgeo%2Fdocs%2F" %}
-{% set urlpath = urlpath ~ env.docname | replace("/", "%2F") ~ ".ipynb" %}
-{% if "dev" in env.config.release %}
-    {% set branch = "main" %}
-{% else %}
-    {% set branch = "releases%2Fv" ~ env.config.version %}
-{% endif %}
-
-.. image:: https://img.shields.io/badge/-Open%20on%20Planetary%20Computer-blue
-    :class: colabbadge
-    :alt: Open on Planetary Computer
-    :target: {{ host }}?repo={{ repo }}&urlpath={{ urlpath }}&branch={{ branch }}
 """
 
 # Disables requirejs in nbsphinx to enable compatibility with the pytorch_sphinx_theme
diff --git a/experiments/ssl4eo/plot_example_predictions.py b/experiments/ssl4eo/plot_example_predictions.py
index 596c8ea304d..96fe9103b9c 100755
--- a/experiments/ssl4eo/plot_example_predictions.py
+++ b/experiments/ssl4eo/plot_example_predictions.py
@@ -63,13 +63,9 @@
         data = sample[key]
         if key == 'image':
             data = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype('uint8')
-            Image.fromarray(data, 'RGB').save(  # type: ignore[no-untyped-call]
-                f'{path}/{key}.png'
-            )
+            Image.fromarray(data, 'RGB').save(f'{path}/{key}.png')
         else:
             data = data * 255 / 4
             data = data.numpy().astype('uint8').squeeze()
-            Image.fromarray(data, 'L').save(  # type: ignore[no-untyped-call]
-                f'{path}/{key}.png'
-            )
+            Image.fromarray(data, 'L').save(f'{path}/{key}.png')
         i += 1
diff --git a/pyproject.toml b/pyproject.toml
index 9707f472d8e..eb43ce1cbdd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,14 +40,16 @@ dependencies = [
     "einops>=0.3",
     # fiona 1.8.21+ required for Python 3.10 wheels
     "fiona>=1.8.21",
-    # kornia 0.7.2+ required for dict support in AugmentationSequential
-    "kornia>=0.7.2",
+    # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
+    "kornia>=0.7.3",
     # lightly 1.4.4+ required for MoCo v3 support
     # lightly 1.4.26 is incompatible with the version of timm required by smp
     # https://github.com/microsoft/torchgeo/issues/1824
     "lightly>=1.4.4,!=1.4.26",
     # lightning 2+ required for LightningCLI args + sys.argv support
-    "lightning[pytorch-extra]>=2",
+    # lightning 2.3+ contains known bugs related to YAML parsing
+    # https://github.com/Lightning-AI/pytorch-lightning/issues/19977
+    "lightning[pytorch-extra]>=2,<2.3",
     # matplotlib 3.5+ required for Python 3.10 wheels
     "matplotlib>=3.5",
     # numpy 1.21.2+ required by Python 3.10 wheels
@@ -266,7 +268,7 @@ quote-style = "single"
 skip-magic-trailing-comma = true
 
 [tool.ruff.lint]
-extend-select = ["D", "I", "UP"]
+extend-select = ["D", "I", "NPY201", "UP"]
 
 [tool.ruff.lint.per-file-ignores]
 "docs/**" = ["D"]
diff --git a/requirements/datasets.txt b/requirements/datasets.txt
index 1c95448b161..0f54889c24a 100644
--- a/requirements/datasets.txt
+++ b/requirements/datasets.txt
@@ -1,11 +1,11 @@
 # datasets
 h5py==3.11.0
-laspy==2.5.3
-opencv-python==4.9.0.80
-pycocotools==2.0.7
-pyvista==0.43.8
+laspy==2.5.4
+opencv-python==4.10.0.84
+pycocotools==2.0.8
+pyvista==0.44.0
 radiant-mlhub==0.4.1
 rarfile==4.2
-scikit-image==0.23.2
-scipy==1.13.1
+scikit-image==0.24.0
+scipy==1.14.0
 zipfile-deflate64==0.2.0
diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old
index d3e72bb5b83..9475f2289f2 100644
--- a/requirements/min-reqs.old
+++ b/requirements/min-reqs.old
@@ -4,7 +4,7 @@ setuptools==61.0.0
 # install
 einops==0.3.0
 fiona==1.8.21
-kornia==0.7.2
+kornia==0.7.3
 lightly==1.4.4
 lightning[pytorch-extra]==2.0.0
 matplotlib==3.5.0
diff --git a/requirements/package-lock.json b/requirements/package-lock.json
index d2d97509fbb..371112f9d23 100644
--- a/requirements/package-lock.json
+++ b/requirements/package-lock.json
@@ -1,11 +1,17 @@
 {
+  "name": "torchgeo",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
+    "": {
+      "dependencies": {
+        "prettier": ">=3.3.3"
+      }
+    },
     "node_modules/prettier": {
-      "version": "3.2.5",
-      "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.2.5.tgz",
-      "integrity": "sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==",
+      "version": "3.3.3",
+      "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz",
+      "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==",
       "bin": {
         "prettier": "bin/prettier.cjs"
       },
diff --git a/requirements/package.json b/requirements/package.json
index ebfe31dde86..0735e2f64ef 100644
--- a/requirements/package.json
+++ b/requirements/package.json
@@ -2,6 +2,6 @@
   "name": "torchgeo",
   "private": "true",
   "dependencies": {
-    "prettier": ">=3.0.0"
+    "prettier": ">=3.3.3"
   }
 }
diff --git a/requirements/required.txt b/requirements/required.txt
index bc8fb922367..9fe1a24ce61 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -1,22 +1,22 @@
 # setup
-setuptools==70.0.0
+setuptools==70.3.0
 # install
 einops==0.8.0
 fiona==1.9.6
-kornia==0.7.2
-lightly==1.5.4
+kornia==0.7.3
+lightly==1.5.9
 lightning[pytorch-extra]==2.2.5
 matplotlib==3.9.0
 numpy==1.26.4
 pandas==2.2.2
-pillow==10.3.0
+pillow==10.4.0
 pyproj==3.6.1
 rasterio==1.3.10
-rtree==1.2.0
+rtree==1.3.0
 segmentation-models-pytorch==0.3.3
-shapely==2.0.4
+shapely==2.0.5
 timm==0.9.2
-torch==2.3.0
+torch==2.3.1
 torchmetrics==1.4.0.post0
-torchvision==0.18.0
+torchvision==0.18.1
diff --git a/requirements/style.txt b/requirements/style.txt
index 82bf7c8d526..9330d471f57 100644
--- a/requirements/style.txt
+++ b/requirements/style.txt
@@ -1,3 +1,3 @@
 # style
-mypy==1.10.0
-ruff==0.4.6
+mypy==1.10.1
+ruff==0.5.2
diff --git a/requirements/tests.txt b/requirements/tests.txt
index c4770eef925..b0a6b1f51b3 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,4 +1,4 @@
 # tests
-nbmake==1.5.3
-pytest==8.2.1
+nbmake==1.5.4
+pytest==8.2.2
 pytest-cov==5.0.0
diff --git a/tests/data/README.md b/tests/data/README.md
index 1d95c728d6d..312bce2a8e8 100644
--- a/tests/data/README.md
+++ b/tests/data/README.md
@@ -20,7 +20,7 @@ with rio.open(os.path.join(ROOT, FILENAME), "r") as src:
     dtype = src.profile["dtype"]
     Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype)
     with rio.open(FILENAME, "w", **src.profile) as dst:
-        for i in dst.profile.indexes:
+        for i in dst.indexes:
             dst.write(Z, i)
 ```
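The one-token fix above is easy to miss: `src.profile` is a dict of creation options (driver, dtype, width, …) with no `indexes` attribute, so the old loop raised `AttributeError`; band indexes are a property of the opened dataset itself. A minimal check (assuming any existing GeoTIFF at `example.tif`):

```python
import rasterio as rio

with rio.open('example.tif') as src:
    print(dict(src.profile))  # creation options only; no 'indexes' key
    print(src.indexes)        # band indexes live on the dataset, e.g. (1, 2, 3)
```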
diff --git a/tests/data/cv4a_kenya_crop_type/FieldIds.csv b/tests/data/cv4a_kenya_crop_type/FieldIds.csv
new file mode 100644
index 00000000000..04ff33b2500
--- /dev/null
+++ b/tests/data/cv4a_kenya_crop_type/FieldIds.csv
@@ -0,0 +1,5 @@
+train,test
+1,2
+3,4
+5
+6
diff --git a/tests/data/cv4a_kenya_crop_type/data.py b/tests/data/cv4a_kenya_crop_type/data.py
new file mode 100755
index 00000000000..e55ffa45b64
--- /dev/null
+++ b/tests/data/cv4a_kenya_crop_type/data.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+from PIL import Image
+
+DTYPE = np.float32
+SIZE = 2
+
+np.random.seed(0)
+
+all_bands = (
+    'B01',
+    'B02',
+    'B03',
+    'B04',
+    'B05',
+    'B06',
+    'B07',
+    'B08',
+    'B8A',
+    'B09',
+    'B11',
+    'B12',
+    'CLD',
+)
+
+for tile in range(1):
+    directory = os.path.join('data', str(tile))
+    os.makedirs(directory, exist_ok=True)
+
+    arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32)
+    img = Image.fromarray(arr)
+    img.save(os.path.join(directory, f'{tile}_field_id.tif'))
+
+    arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8)
+    img = Image.fromarray(arr)
+    img.save(os.path.join(directory, f'{tile}_label.tif'))
+
+    for date in ['20190606']:
+        directory = os.path.join(directory, date)
+        os.makedirs(directory, exist_ok=True)
+
+        for band in all_bands:
+            arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max
+            img = Image.fromarray(arr)
+            img.save(os.path.join(directory, f'{tile}_{band}_{date}.tif'))
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif b/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif
new file mode 100644
index 00000000000..f72a6772091
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif b/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif
new file mode 100644
index 00000000000..c0555ad107b
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif
new file mode 100644
index 00000000000..1311d977f55
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif
new file mode 100644
index 00000000000..ad41e11ea7b
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif
new file mode 100644
index 00000000000..294e70e13f8
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif
new file mode 100644
index 00000000000..704c8dfc23d
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif
new file mode 100644
index 00000000000..a0aa5478a3a
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif
new file mode 100644
index 00000000000..834e92f43b5
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif
new file mode 100644
index 00000000000..58f58df0767
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif
new file mode 100644
index 00000000000..f534bde3167
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif
new file mode 100644
index 00000000000..b931b7189b0
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif
new file mode 100644
index 00000000000..ea661cbc40e
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif
new file mode 100644
index 00000000000..017b1714532
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif
new file mode 100644
index 00000000000..1e3f7ce38b8
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif differ
diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif
new file mode 100644
index 00000000000..1ec85420866
Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif differ
diff --git a/tests/data/cyclone/data.py b/tests/data/cyclone/data.py
new file mode 100755
index 00000000000..2ea0f7a425a
--- /dev/null
+++ b/tests/data/cyclone/data.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+
+DTYPE = np.uint8
+SIZE = 2
+
+np.random.seed(0)
+
+for split in ['train', 'test']:
+    os.makedirs(split, exist_ok=True)
+
+    filename = split
+    if split == 'train':
+        filename = 'training'
+
+    features = pd.read_csv(f'{filename}_set_features.csv')
+    for image_id, _, _, ocean in features.values:
+        size = (SIZE, SIZE)
+        if ocean % 2 == 0:
+            size = (SIZE * 2, SIZE * 2, 3)
+
+        arr = np.random.randint(np.iinfo(DTYPE).max, size=size, dtype=DTYPE)
+        img = Image.fromarray(arr)
+        img.save(os.path.join(split, f'{image_id}.jpg'))
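One detail of the new `tests/data/cyclone/data.py` worth noting: varying the array shape on `ocean % 2` makes the fixture emit both grayscale and RGB JPEGs, because `PIL.Image.fromarray` infers the image mode from the array's shape:

```python
import numpy as np
from PIL import Image

print(Image.fromarray(np.zeros((2, 2), dtype=np.uint8)).mode)     # 'L'  – 2-D array
print(Image.fromarray(np.zeros((2, 2, 3), dtype=np.uint8)).mode)  # 'RGB' – H x W x 3
```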
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz
deleted file mode 100644
index cbfa3779d9a..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json
deleted file mode 100644
index a5692a66e5e..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz
deleted file mode 100644
index 7a8162fafdf..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json
deleted file mode 100644
index 97c44e9907a..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_test_source_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json
deleted file mode 100644
index 83438ddffa4..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "a", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json
deleted file mode 100644
index 13f4a63afaa..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "b", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json
deleted file mode 100644
index d8671e26416..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "c", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json
deleted file mode 100644
index a6eebd660e0..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "d", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json
deleted file mode 100644
index 90267dc6f1f..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "e", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz
deleted file mode 100644
index 83f9138674e..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json
deleted file mode 100644
index 834d293998a..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json
deleted file mode 100644
index e59bae96dc9..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json
+++ /dev/null
@@ -1 +0,0 @@
-{"wind_speed": "34"}
\ No newline at end of file
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz
deleted file mode 100644
index b3f019e97c7..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json
deleted file mode 100644
index a03e0c77a19..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_train_source_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json
deleted file mode 100644
index 83438ddffa4..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "a", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg
deleted file mode 100644
index 79c38f2a929..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json
deleted file mode 100644
index 13f4a63afaa..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "b", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json
deleted file mode 100644
index d8671e26416..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "c", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg
deleted file mode 100644
index 79c38f2a929..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json
deleted file mode 100644
index a6eebd660e0..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "d", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg
deleted file mode 100644
index 79c38f2a929..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg and /dev/null differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json
deleted file mode 100644
index 90267dc6f1f..00000000000
--- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json
+++ /dev/null
@@ -1 +0,0 @@
-{"storm_id": "e", "relative_time": "0", "ocean": "2"}
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg
deleted file mode 100644
index 79c38f2a929..00000000000
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg and /dev/null differ
diff --git a/tests/data/cyclone/test/aaa_000.jpg b/tests/data/cyclone/test/aaa_000.jpg
new file mode 100644
index 00000000000..f4d039da97c
Binary files /dev/null and b/tests/data/cyclone/test/aaa_000.jpg differ
diff --git a/tests/data/cyclone/test/bbb_111.jpg b/tests/data/cyclone/test/bbb_111.jpg
new file mode 100644
index 00000000000..0d8e7a84a23
Binary files /dev/null and b/tests/data/cyclone/test/bbb_111.jpg differ
diff --git a/tests/data/cyclone/test/ccc_222.jpg b/tests/data/cyclone/test/ccc_222.jpg
new file mode 100644
index 00000000000..ebd3ba67c09
Binary files /dev/null and b/tests/data/cyclone/test/ccc_222.jpg differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg b/tests/data/cyclone/test/ddd_333.jpg
similarity index 74%
rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg
rename to tests/data/cyclone/test/ddd_333.jpg
index 79c38f2a929..575d5a5c69f 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg and b/tests/data/cyclone/test/ddd_333.jpg differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg b/tests/data/cyclone/test/eee_444.jpg
similarity index 82%
rename from tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg
rename to tests/data/cyclone/test/eee_444.jpg
index 77c95fe8774..0cd10728e84 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg and b/tests/data/cyclone/test/eee_444.jpg differ
diff --git a/tests/data/cyclone/test_set_features.csv b/tests/data/cyclone/test_set_features.csv
new file mode 100644
index 00000000000..dce291b0b5c
--- /dev/null
+++ b/tests/data/cyclone/test_set_features.csv
@@ -0,0 +1,6 @@
+Image ID,Storm ID,Relative Time,Ocean
+aaa_000,aaa,0,0
+bbb_111,bbb,1,1
+ccc_222,ccc,2,2
+ddd_333,ddd,3,3
+eee_444,eee,4,4
diff --git a/tests/data/cyclone/test_set_labels.csv b/tests/data/cyclone/test_set_labels.csv
new file mode 100644
index 00000000000..8aa2d7c7f67
--- /dev/null
+++ b/tests/data/cyclone/test_set_labels.csv
@@ -0,0 +1,6 @@
+Image ID,Wind Speed
+aaa_000,0
+bbb_111,1
+ccc_222,2
+ddd_333,3
+eee_444,4
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg b/tests/data/cyclone/train/fff_555.jpg
similarity index 73%
rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg
rename to tests/data/cyclone/train/fff_555.jpg
index 79c38f2a929..15225859b03 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg and b/tests/data/cyclone/train/fff_555.jpg differ
diff --git a/tests/data/cyclone/train/ggg_666.jpg b/tests/data/cyclone/train/ggg_666.jpg
new file mode 100644
index 00000000000..3065b52a80b
Binary files /dev/null and b/tests/data/cyclone/train/ggg_666.jpg differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg b/tests/data/cyclone/train/hhh_777.jpg
similarity index 75%
rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg
rename to tests/data/cyclone/train/hhh_777.jpg
index 79c38f2a929..877ac76c481 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg and b/tests/data/cyclone/train/hhh_777.jpg differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg b/tests/data/cyclone/train/iii_888.jpg
similarity index 82%
rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg
rename to tests/data/cyclone/train/iii_888.jpg
index 77c95fe8774..731128b8a0c 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg and b/tests/data/cyclone/train/iii_888.jpg differ
diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg b/tests/data/cyclone/train/jjj_999.jpg
similarity index 75%
rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg
rename to tests/data/cyclone/train/jjj_999.jpg
index 79c38f2a929..8fda5ace924 100644
Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg and b/tests/data/cyclone/train/jjj_999.jpg differ
diff --git a/tests/data/cyclone/training_set_features.csv b/tests/data/cyclone/training_set_features.csv
new file mode 100644
index 00000000000..56df786e8ef
--- /dev/null
+++ b/tests/data/cyclone/training_set_features.csv
@@ -0,0 +1,6 @@
+Image ID,Storm ID,Relative Time,Ocean
+fff_555,fff,5,5
+ggg_666,ggg,6,6
+hhh_777,hhh,7,7
+iii_888,iii,8,8
+jjj_999,jjj,9,9
diff --git a/tests/data/cyclone/training_set_labels.csv b/tests/data/cyclone/training_set_labels.csv
new file mode 100644
index 00000000000..5a8bbabce8c
--- /dev/null
+++ b/tests/data/cyclone/training_set_labels.csv
@@ -0,0 +1,6 @@
+Image ID,Wind Speed
+fff_555,5
+ggg_666,6
+hhh_777,7
+iii_888,8
+jjj_999,9
diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz
deleted file mode 100644
index 1c642bf9c73..00000000000
Binary files a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz and /dev/null differ
diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz
deleted file mode 100644
index f5e0e289137..00000000000
Binary files a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py
index e8a771e0fa5..1523af6a14e 100755
--- a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py
+++ b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py
@@ -1,275 +1,42 @@
+#!/usr/bin/env python3
+
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
 import os
-from datetime import datetime as dt
-from pathlib import Path
 
 import numpy as np
 import rasterio
-from pystac import (
-    Asset,
-    CatalogType,
-    Collection,
-    Extent,
-    Item,
-    Link,
-    MediaType,
-    SpatialExtent,
-    TemporalExtent,
-)
-from pystac.extensions.eo import Band, EOExtension
-from pystac.extensions.label import (
-    LabelClasses,
-    LabelCount,
-    LabelExtension,
-    LabelOverview,
-    LabelType,
-)
+from rasterio import Affine
 from rasterio.crs import CRS
-from rasterio.transform import Affine
-
-np.random.seed(0)
-
-SIZE = 512
-BANDS = ['B02', 'B03', 'B04', 'B08']
+SIZE = 2
+DTYPE = np.uint16
 
-SOURCE_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_source'
-SOURCE_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_source_aaaa'
-LABEL_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels'
-LABEL_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels_aaaa'
+np.random.seed(0)
 
-# geometry used by both source and label items
-TEST_GEOMETRY = {
-    'type': 'Polygon',
-    'coordinates': [
-        [
-            [137.86580132892396, -29.52744848758255],
-            [137.86450090473795, -29.481297003404038],
-            [137.91724642199793, -29.48015007212528],
-            [137.9185707094313, -29.526299409555623],
-            [137.86580132892396, -29.52744848758255],
-        ]
-    ],
+splits = {'train': 'public', 'test': 'private'}
+chip_ids = ['aaaa']
+all_bands = ['B02', 'B03', 'B04', 'B08']
+profile = {
+    'driver': 'GTiff',
+    'dtype': DTYPE,
+    'width': SIZE,
+    'height': SIZE,
+    'count': 1,
+    'crs': CRS.from_epsg(32753),
+    'transform': Affine(10.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0),
 }
-
-# bbox used by both source and label items
-TEST_BBOX = [
-    137.86450090473795,
-    -29.52744848758255,
-    137.9185707094313,
-    -29.48015007212528,
-]
-
-# sentinel-2 bands for EO extension
-S2_BANDS = [
-    Band.create(name='B02', common_name='blue', description='Blue'),
-    Band.create(name='B03', common_name='green', description='Green'),
-    Band.create(name='B04', common_name='red', description='Red'),
-    Band.create(name='B08', common_name='nir', description='NIR'),
-]
-
-# class map for overviews
-CLASS_COUNT_MAP = {'0': 'no cloud', '1': 'cloud'}
-
-# define the spatial and temporal extent of collections
-TEST_EXTENT = Extent(
-    spatial=SpatialExtent(
-        bboxes=[
-            [
-                -80.05464265420176,
-                -53.31380701212582,
-                151.75593282192196,
-                35.199126843018696,
-            ]
-        ]
-    ),
-    temporal=TemporalExtent(
-        intervals=[
-            [
-                dt.strptime('2018-02-18', '%Y-%m-%d'),
-                dt.strptime('2020-09-13', '%Y-%m-%d'),
-            ]
-        ]
-    ),
-)
-
-
-def create_raster(path: str, dtype: str, num_channels: int, collection: str) -> None:
-    if not os.path.exists(os.path.split(path)[0]):
-        Path(os.path.split(path)[0]).mkdir(parents=True)
-
-    profile = {}
-    profile['driver'] = 'GTiff'
-    profile['dtype'] = dtype
-    profile['count'] = num_channels
-    profile['crs'] = CRS.from_epsg(32753)
-    profile['transform'] = Affine(1.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0)
-    profile['height'] = SIZE
-    profile['width'] = SIZE
-    profile['compress'] = 'lzw'
-    profile['predictor'] = 2
-
-    if collection == 'source':
-        if 'float' in profile['dtype']:
-            Z = np.random.randn(SIZE, SIZE).astype(profile['dtype'])
-        else:
-            Z = np.random.randint(
-                np.iinfo(profile['dtype']).max,
-                size=(SIZE, SIZE),
-                dtype=profile['dtype'],
-            )
-    elif collection == 'labels':
-        Z = np.random.randint(0, 2, (SIZE, SIZE)).astype(profile['dtype'])
-
-    with rasterio.open(path, 'w', **profile) as src:
-        for i in range(1, profile['count'] + 1):
-            src.write(Z, i)
-
-
-def create_source_item() -> Item:
-    # instantiate source Item
-    test_source_item = Item(
-        id=SOURCE_ITEM_ID,
-        geometry=TEST_GEOMETRY,
-        bbox=TEST_BBOX,
-        datetime=dt.strptime('2020-06-03', '%Y-%m-%d'),
-        properties={},
-    )
-
-    # add Asset with EO Extension for each S2 band
-    for band in BANDS:
-        img_path = os.path.join(
-            os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{band}.tif'
-        )
-        image_asset = Asset(href=img_path, media_type=MediaType.GEOTIFF)
-        eo_asset_ext = EOExtension.ext(image_asset)
-
-        for s2_band in S2_BANDS:
-            if s2_band.name == band:
-                eo_asset_ext.apply(bands=[s2_band])
-        test_source_item.add_asset(key=band, asset=image_asset)
-
-    eo_image_ext = EOExtension.ext(test_source_item, add_if_missing=True)
-    eo_image_ext.apply(bands=S2_BANDS)
-
-    return test_source_item
-
-
-def get_class_label_list(overview: LabelOverview) -> LabelClasses:
-    label_list = [d['name'] for d in overview.properties['counts']]
-    label_classes = LabelClasses.create(classes=label_list, name='labels')
-    return label_classes
-
-
-def get_item_class_overview(label_type: LabelType, asset_path: str) -> LabelOverview:
-    """Takes a path to an asset based on type and returns the class label
-    overview object
-
-    Args:
-        label_type: LabelType - the type of label, either RASTER or VECTOR
-        asset_path: str - path to the asset to read in either a raster image or
-            geojson vector
-
-    Returns:
-        overview: LabelOverview - the STAC LabelOverview object containing label classes
-
-    """
-
-    count_list = []
-
-    img_arr = rasterio.open(asset_path).read()
-    value_count = np.unique(img_arr.flatten(), return_counts=True)
-
-    for ix, classy in enumerate(value_count[0]):
-        if classy > 0:
-            label_count = LabelCount.create(
-                name=CLASS_COUNT_MAP[str(int(classy))], count=int(value_count[1][ix])
-            )
-            count_list.append(label_count)
-
-    overview = LabelOverview(properties={})
-    overview.apply(property_key='labels', counts=count_list)
-
-    return overview
-
-
-def create_label_item() -> Item:
-    # instantiate label Item
-    test_label_item = Item(
-        id=LABEL_ITEM_ID,
-        geometry=TEST_GEOMETRY,
-        bbox=TEST_BBOX,
-        datetime=dt.strptime('2020-06-03', '%Y-%m-%d'),
-        properties={},
-    )
-
-    label_overview = get_item_class_overview(LabelType.RASTER, label_path)
-    label_list = get_class_label_list(label_overview)
-
-    label_ext = LabelExtension.ext(test_label_item, add_if_missing=True)
-    label_ext.apply(
-        label_description='Sentinel-2 Cloud Cover Segmentation Test Labels',
-        label_type=LabelType.RASTER,
-        label_classes=[label_list],
-        label_overviews=[label_overview],
-    )
-
-    label_asset = Asset(href=label_path, media_type=MediaType.GEOTIFF)
-    test_label_item.add_asset(key='labels', asset=label_asset)
-
-    return test_label_item
-
-
-if __name__ == '__main__':
-    # create a geotiff for each s2 band
-    for b in BANDS:
-        tif_path = os.path.join(
-            os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{b}.tif'
-        )
-        create_raster(tif_path, 'uint8', 1, 'source')
-
-    # create a geotiff for label
-    label_path = os.path.join(
-        os.getcwd(), LABEL_COLLECTION_ID, LABEL_ITEM_ID, 'labels.tif'
-    )
-    create_raster(label_path, 'uint8', 1, 'labels')
-
-    # instantiate the source Collection
-    test_source_collection = Collection(
-        id=SOURCE_COLLECTION_ID,
-        description='Test Source Collection for Torchgo Cloud Cover Detection Dataset',
-        extent=TEST_EXTENT,
-        catalog_type=CatalogType.RELATIVE_PUBLISHED,
-        license='CC-BY-4.0',
-    )
-
-    source_item = create_source_item()
-    test_source_collection.add_item(source_item)
-
-    test_source_collection.normalize_hrefs(
-        os.path.join(os.getcwd(), SOURCE_COLLECTION_ID)
-    )
-    test_source_collection.make_all_asset_hrefs_relative()
-    test_source_collection.save(catalog_type=CatalogType.SELF_CONTAINED)
-
-    # instantiate the label Collection
-    test_label_collection = Collection(
-        id=LABEL_COLLECTION_ID,
-        description='Test Label Collection for Torchgo Cloud Cover Detection Dataset',
-        extent=TEST_EXTENT,
-        catalog_type=CatalogType.RELATIVE_PUBLISHED,
-        license='CC-BY-4.0',
-    )
-
-    label_item = create_label_item()
-    label_item.add_link(
-        Link(rel='source', target=source_item, media_type=MediaType.GEOTIFF)
-    )
-    test_label_collection.add_item(label_item)
-
-    test_label_collection.normalize_hrefs(
-        os.path.join(os.getcwd(), LABEL_COLLECTION_ID)
-    )
-    test_label_collection.make_all_asset_hrefs_relative()
-    test_label_collection.save(catalog_type=CatalogType.SELF_CONTAINED)
+Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
+
+for split, directory in splits.items():
+    for chip_id in chip_ids:
+        path = os.path.join(directory, f'{split}_features', chip_id)
+        os.makedirs(path, exist_ok=True)
+        for band in all_bands:
+            with rasterio.open(os.path.join(path, f'{band}.tif'), 'w', **profile) as f:
+                f.write(Z, 1)
+        path = os.path.join(directory, f'{split}_labels')
+        os.makedirs(path, exist_ok=True)
+        with rasterio.open(os.path.join(path, f'{chip_id}.tif'), 'w', **profile) as f:
+            f.write(Z, 1)
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif differ
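For orientation on the profile used by the rewritten `data.py` above: `Affine(10.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0)` pins 10 m pixels to an EPSG:32753 (UTM zone 53S) origin, with y decreasing downward. A worked check of the (col, row) → (x, y) mapping rasterio uses:

```python
from rasterio import Affine

t = Affine(10.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0)
print(t * (0, 0))  # (777760.0, 6735270.0) – upper-left corner of the chip
print(t * (2, 2))  # (777780.0, 6735250.0) – 10 m east per column, 10 m south per row
```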
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv
new file mode 100644
index 00000000000..17c43ad2cef
--- /dev/null
+++ b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv
@@ -0,0 +1,2 @@
+chip_id,location,datetime
+aaaa,Australia - Central East,2024-06-13T00:00:00Z
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif
new file mode 100644
index 00000000000..79ce7c0a3bf
Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv
new file mode 100644
index 00000000000..17c43ad2cef
--- /dev/null
+++ b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv
@@ -0,0 +1,2 @@
+chip_id,location,datetime
+aaaa,Australia - Central East,2024-06-13T00:00:00Z
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz
deleted file mode 100644
index 8aa7da7a185..00000000000
Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json
deleted file mode 100644
index 3f50cb6b5db..00000000000
--- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json
+++ /dev/null
@@ -1,40 +0,0 @@
-{
-    "type": "Collection",
-    "id": "ref_cloud_cover_detection_challenge_v1_test_labels",
-    "stac_version": "1.0.0",
-    "description": "Test Label Collection for Torchgo Cloud Cover Detection Dataset",
-    "links": [
-        {
-            "rel": "root",
-            "href": "./collection.json",
-            "type": "application/json"
-        },
-        {
-            "rel": "item",
-            "href": "./ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json",
-            "type": "application/json"
-        }
-    ],
-    "stac_extensions": [],
-    "extent": {
-        "spatial": {
-            "bbox": [
-                [
-                    -80.05464265420176,
-                    -53.31380701212582,
-                    151.75593282192196,
-                    35.199126843018696
-                ]
-            ]
-        },
-        "temporal": {
-            "interval": [
-                [
-                    "2018-02-18T00:00:00Z",
-                    "2020-09-13T00:00:00Z"
-                ]
-            ]
-        }
-    },
-    "license": "CC-BY-4.0"
-}
\ No newline at end of file
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif
deleted file mode 100644
index 0181ba4f573..00000000000
Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json
deleted file mode 100644
index 7633d8b46d1..00000000000
--- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json
+++ /dev/null
@@ -1,95 +0,0 @@
-{
-    "type": "Feature",
-    "stac_version": "1.0.0",
-    "id": "ref_cloud_cover_detection_challenge_v1_test_labels_aaaa",
-    "properties": {
-        "label:description": "Sentinel-2 Cloud Cover Segmentation Test Labels",
-        "label:type": "raster",
-        "label:properties": null,
-        "label:classes": [
-            {
-                "classes": [
-                    "cloud"
-                ],
-                "name": "labels"
-            }
-        ],
-        "label:overviews": [
-            {
-                "property_key": "labels",
-                "counts": [
-                    {
-                        "name": "cloud",
-                        "count": 130696
-                    }
-                ]
-            }
-        ],
-        "datetime": "2020-06-03T00:00:00Z"
-    },
-    "geometry": {
-        "type": "Polygon",
-        "coordinates": [
-            [
-                [
-                    137.86580132892396,
-                    -29.52744848758255
-                ],
-                [
-                    137.86450090473795,
-                    -29.481297003404038
-                ],
-                [
-                    137.91724642199793,
-                    -29.48015007212528
-                ],
-                [
-                    137.9185707094313,
-                    -29.526299409555623
-                ],
-                [
-                    137.86580132892396,
-                    -29.52744848758255
-                ]
-            ]
-        ]
-    },
-    "links": [
-        {
-            "rel": "source",
-            "href": "../../ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json",
-            "type": "image/tiff; application=geotiff"
-        },
-        {
-            "rel": "root",
-            "href": "../collection.json",
-            "type": "application/json"
-        },
-        {
-            "rel": "collection",
-            "href": "../collection.json",
-            "type": "application/json"
-        },
-        {
-            "rel": "parent",
-            "href": "../collection.json",
-            "type": "application/json"
-        }
-    ],
-    "assets": {
-        "labels": {
-            "href": "./labels.tif",
-            "type": "image/tiff; application=geotiff"
-        }
-    },
-    "bbox": [
-        137.86450090473795,
-        -29.52744848758255,
-        137.9185707094313,
-        -29.48015007212528
-    ],
-    "stac_extensions": [
-        "https://stac-extensions.github.io/label/v1.0.1/schema.json"
-    ],
-    "collection": "ref_cloud_cover_detection_challenge_v1_test_labels"
-}
\ No newline at end of file
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz
deleted file mode 100644
index b65bd0849b5..00000000000
Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json
deleted file mode 100644
index 3f19bed1e2b..00000000000
--- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json
+++ /dev/null
@@ -1,40 +0,0 @@
-{
-    "type": "Collection",
-    "id": "ref_cloud_cover_detection_challenge_v1_test_source",
-    "stac_version": "1.0.0",
-    "description": "Test Source Collection for Torchgo Cloud Cover Detection Dataset",
-    "links": [
-        {
-            "rel": "root",
-            "href": "./collection.json",
-            "type": "application/json"
-        },
-        {
-            "rel": "item",
-            "href": "./ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json",
-            "type": "application/json"
-        }
-    ],
-    "stac_extensions": [],
-    "extent": {
-        "spatial": {
-            "bbox": [
-                [
-                    -80.05464265420176,
-                    -53.31380701212582,
-                    151.75593282192196,
-                    35.199126843018696
-                ]
-            ]
-        },
-        "temporal": {
-            "interval": [
-                [
-                    "2018-02-18T00:00:00Z",
-                    "2020-09-13T00:00:00Z"
-                ]
-            ]
-        }
-    },
-    "license": "CC-BY-4.0"
-}
\ No newline at end of file
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif
deleted file mode 100644
index 5f23bc2b562..00000000000
Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif
deleted file mode 100644
index f143ae2c3fb..00000000000
Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif and /dev/null differ
diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif
deleted file
mode 100644 index b1d91415d52..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif deleted file mode 100644 index 111b1d7af26..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json deleted file mode 100644 index aebe445f34a..00000000000 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "type": "Feature", - "stac_version": "1.0.0", - "id": "ref_cloud_cover_detection_challenge_v1_test_source_aaaa", - "properties": { - "eo:bands": [ - { - "name": "B02", - "common_name": "blue", - "description": "Blue" - }, - { - "name": "B03", - "common_name": "green", - "description": "Green" - }, - { - "name": "B04", - "common_name": "red", - "description": "Red" - }, - { - "name": "B08", - "common_name": "nir", - "description": "NIR" - } - ], - "datetime": "2020-06-03T00:00:00Z" - }, - "geometry": { - "type": "Polygon", - "coordinates": [ - [ - [ - 137.86580132892396, - -29.52744848758255 - ], - [ - 137.86450090473795, - -29.481297003404038 - ], - [ - 137.91724642199793, - -29.48015007212528 - ], - [ - 137.9185707094313, - -29.526299409555623 - ], - [ - 137.86580132892396, - -29.52744848758255 - ] - ] - ] - }, - "links": [ - { - "rel": "root", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "collection", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "parent", - "href": "../collection.json", - "type": "application/json" - } - ], - "assets": { - "B02": { - "href": "./B02.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B02", - "common_name": "blue", - "description": "Blue" - } - ] - }, - "B03": { - "href": "./B03.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B03", - "common_name": "green", - "description": "Green" - } - ] - }, - "B04": { - "href": "./B04.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B04", - "common_name": "red", - "description": "Red" - } - ] - }, - "B08": { - "href": "./B08.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B08", - "common_name": "nir", - "description": "NIR" - } - ] - } - }, - "bbox": [ - 137.86450090473795, - -29.52744848758255, - 137.9185707094313, - -29.48015007212528 - ], 
- "stac_extensions": [ - "https://stac-extensions.github.io/eo/v1.0.0/schema.json" - ], - "collection": "ref_cloud_cover_detection_challenge_v1_test_source" -} \ No newline at end of file diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py old mode 100644 new mode 100755 index a3522e8c962..bf9954e8935 --- a/tests/data/rwanda_field_boundary/data.py +++ b/tests/data/rwanda_field_boundary/data.py @@ -3,99 +3,46 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import hashlib import os -import shutil import numpy as np import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') all_bands = ('B01', 'B02', 'B03', 'B04') SIZE = 32 -NUM_SAMPLES = 5 +DTYPE = np.uint16 +NUM_SAMPLES = 1 np.random.seed(0) - -def create_mask(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint8', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'blockysize': 32, - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1) - - -def create_img(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint16', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'blockysize': 16, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1) - - -if __name__ == '__main__': - # Train and test images - for split in ('train', 'test'): - for i in range(NUM_SAMPLES): - for date in dates: - directory = os.path.join( - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - ) - os.makedirs(directory, exist_ok=True) - for band in all_bands: - create_img(os.path.join(directory, f'{band}.tif')) - - # Create collections.json, this isn't used by the dataset but is checked to - # exist - with open( - f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json', - 'w', - ) as f: - f.write('Not used') - - # Train labels - for i in range(NUM_SAMPLES): - directory = os.path.join( - 'nasa_rwanda_field_boundary_competition_labels_train', - f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}', - ) - os.makedirs(directory, exist_ok=True) - create_mask(os.path.join(directory, 'raster_labels.tif')) - - # Create directories and compute checksums - for filename in [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_source_test', - 'nasa_rwanda_field_boundary_competition_labels_train', - ]: - shutil.make_archive(filename, 'gztar', '.', filename) - # Compute checksums - with open(f'{filename}.tar.gz', 'rb') as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(f'{filename}: {md5}') +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(3857), + 'transform': Affine( + 4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805 + ), +} +Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), 
dtype=DTYPE) + +for sample in range(NUM_SAMPLES): + for split in ['train', 'test']: + for date in dates: + path = os.path.join('source', split, date) + os.makedirs(path, exist_ok=True) + for band in all_bands: + file = os.path.join(path, f'{sample:02}_{band}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) + + path = os.path.join('labels', 'train') + os.makedirs(path, exist_ok=True) + file = os.path.join(path, f'{sample:02}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) diff --git a/tests/data/rwanda_field_boundary/labels/train/00.tif b/tests/data/rwanda_field_boundary/labels/train/00.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/labels/train/00.tif differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz deleted file mode 100644 index ffa98bb53d6..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz deleted file mode 100644 index a834f66bf38..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz deleted file mode 100644 index 8239f70c200..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif 
new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif new file mode 
100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif new file 
mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif 
b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif differ diff --git a/tests/data/technoserve-cashew-benin/data.py b/tests/data/technoserve-cashew-benin/data.py new file mode 100755 index 00000000000..7d0ec9d58bb --- /dev/null +++ b/tests/data/technoserve-cashew-benin/data.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +DTYPE = np.uint16 +SIZE = 2 + +np.random.seed(0) + +dates = ('00_20191105',) +all_bands = ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', +) +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(32631), + 'transform': Affine( + 10.002549584378608, + 0.0, + 440853.29890114715, + 0.0, + -9.99842989423825, + 1012804.082877621, + ), +} + +for date in dates: + os.makedirs(os.path.join('imagery', '00', date), exist_ok=True) + for band in all_bands: + Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + path = os.path.join('imagery', '00', date, f'{date}_{band}_10m.tif') + with rasterio.open(path, 'w', **profile) as src: + src.write(Z, 1) diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif new file mode 100644 index 00000000000..e459a3a490d Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif new file mode 100644 index 00000000000..fd4ca7ce56b Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif new file mode 100644 index 00000000000..33b458bef62 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif new file mode 100644 index 00000000000..76ca0fbd89d Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif new file mode 100644 index 00000000000..a73de74ec33 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif new file mode 100644 index 00000000000..65d8ef98d17 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif new file mode 100644 index 00000000000..558bbd08853 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif 
b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif new file mode 100644 index 00000000000..532a7d37cf6 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif new file mode 100644 index 00000000000..7111bb5dbba Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif new file mode 100644 index 00000000000..68106c1669b Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif new file mode 100644 index 00000000000..4ea3767ce4c Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif new file mode 100644 index 00000000000..6f7df54f0b3 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif new file mode 100644 index 00000000000..41f05d9ba2f Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/labels/00.geojson b/tests/data/technoserve-cashew-benin/labels/00.geojson new file mode 100644 index 00000000000..ba92dace006 --- /dev/null +++ b/tests/data/technoserve-cashew-benin/labels/00.geojson @@ -0,0 +1,8 @@ +{ +"type": "FeatureCollection", +"name": "cashew_benin", +"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:EPSG::32631" } }, +"features": [ +{ "type": "Feature", "properties": { "OBJECTID": 1, "class": 1, "Shape_Leng": 367629.52331100003, "Shape_Area": 16997542.377500001, "class_name": "Well-managed plantation" }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 447131.214800000190735, 1001286.6359 ], [ 447166.272199999541044, 1001285.31299999915 ], [ 447196.037899999879301, 1001285.31299999915 ], [ 447244.324400000274181, 1001283.3286 ], [ 447256.230700000189245, 1001282.00569999963 ], [ 447253.58490000013262, 1001248.9327 ], [ 447254.907800000160933, 1001228.4275 ], [ 447252.923700000159442, 1001212.05179999955 ], [ 447087.555899999978, 1001212.333799999207 ], [ 447082.266800000332296, 1001241.656599999988 ], [ 447076.97510000038892, 1001256.208799999207 ], [ 447074.329300000332296, 1001286.6359 ], [ 447131.214800000190735, 1001286.6359 ] ] ] } } +] +} diff --git a/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz b/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz deleted file mode 100644 index 5a9d7d22a18..00000000000 Binary files a/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz and /dev/null differ
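
[Reviewer note] The data.py scripts above replace the old Radiant MLHub tarballs with tiny, deterministic GeoTIFF fixtures checked in as plain files. A minimal sketch of how one might sanity-check a regenerated fixture before committing it (the path and the SIZE/DTYPE/EPSG values mirror the technoserve-cashew-benin script above; the round-trip check itself is generic rasterio usage and not part of this diff):

import os

import numpy as np
import rasterio

# Read back one fixture band written by data.py and confirm the profile
# survived the round-trip (assumed values taken from the script above:
# SIZE = 2, DTYPE = np.uint16, EPSG:32631).
path = os.path.join('imagery', '00', '00_20191105', '00_20191105_B02_10m.tif')
with rasterio.open(path) as src:
    band = src.read(1)                 # the script writes a single band
    assert band.shape == (2, 2)        # SIZE = 2
    assert band.dtype == np.uint16     # DTYPE = np.uint16
    assert src.crs.to_epsg() == 32631  # CRS.from_epsg(32631) in the profile
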
diff --git a/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz b/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz deleted file mode 100644 index 1e94b5526b0..00000000000 Binary files a/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz and /dev/null differ diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index 1e960527a84..cc06c0060a0 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -18,44 +16,22 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'ts_cashew_benin', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestBeninSmallHolderCashews: @pytest.fixture def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path ) -> BeninSmallHolderCashews: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - source_md5 = '255efff0f03bc6322470949a09bc76db' - labels_md5 = 'ed2195d93ca6822d48eb02bc3e81c127' - monkeypatch.setitem(BeninSmallHolderCashews.image_meta, 'md5', source_md5) - monkeypatch.setitem(BeninSmallHolderCashews.target_meta, 'md5', labels_md5) - monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('2019_11_05',)) + url = os.path.join('tests', 'data', 'technoserve-cashew-benin') + monkeypatch.setattr(BeninSmallHolderCashews, 'url', url) + monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',)) + monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2) + monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2) root = str(tmp_path) transforms = nn.Identity() - bands = BeninSmallHolderCashews.all_bands - - return BeninSmallHolderCashews( - root, - transforms=transforms, - bands=bands, - download=True, - api_key='', - checksum=True, - verbose=True, - ) + return BeninSmallHolderCashews(root, transforms=transforms, download=True) def test_getitem(self, dataset: BeninSmallHolderCashews) -> None: x = dataset[0] @@ -66,15 +42,15 @@ def test_getitem(self, dataset: BeninSmallHolderCashews) -> None: assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: BeninSmallHolderCashews) -> None: - assert len(dataset) == 72 + assert len(dataset) == 1 def test_add(self, dataset: BeninSmallHolderCashews) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 144 + assert len(ds) == 2 def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None: - BeninSmallHolderCashews(root=dataset.root, download=True, api_key='') + BeninSmallHolderCashews(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): @@ -82,9 +58,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - BeninSmallHolderCashews(bands=['B01', 'B02']) # type: ignore[arg-type] - - with pytest.raises(ValueError, match='is an invalid 
band name.'): BeninSmallHolderCashews(bands=('foo', 'bar')) def test_plot(self, dataset: BeninSmallHolderCashews) -> None: diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index e1dc89483c4..dae87cf3633 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -1,15 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn +from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchgeo.datasets import ( @@ -17,62 +16,30 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_cloud_cover_detection_challenge_v1', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestCloudCoverDetection: - @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - - test_image_meta = { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', - 'md5': '542e64a6e39b53c84c6462ec1b989e43', - } - monkeypatch.setitem(CloudCoverDetection.image_meta, 'test', test_image_meta) - - test_target_meta = { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', - 'md5': 'e8d41de08744a9845e74fca1eee3d1d3', - } - monkeypatch.setitem(CloudCoverDetection.target_meta, 'test', test_target_meta) - + @pytest.fixture(params=['train', 'test']) + def dataset( + self, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + request: SubRequest, + ) -> CloudCoverDetection: + url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1') + monkeypatch.setattr(CloudCoverDetection, 'url', url) root = str(tmp_path) - split = 'test' + split = request.param transforms = nn.Identity() - return CloudCoverDetection( - root=root, - transforms=transforms, - split=split, - download=True, - api_key='', - checksum=True, + root=root, split=split, transforms=transforms, download=True ) def test_invalid_band(self, dataset: CloudCoverDetection) -> None: - invalid_bands = ['B09'] - with pytest.raises(ValueError): - CloudCoverDetection( - root=dataset.root, - split='test', - download=False, - api_key='', - bands=invalid_bands, - ) + with pytest.raises(AssertionError): + CloudCoverDetection(root=dataset.root, split=dataset.split, bands=['B09']) def test_getitem(self, dataset: CloudCoverDetection) -> None: x = dataset[0] @@ -84,28 +51,23 @@ def test_len(self, dataset: CloudCoverDetection) -> None: assert len(dataset) == 1 def test_already_downloaded(self, dataset: CloudCoverDetection) -> None: - CloudCoverDetection(root=dataset.root, split='test', download=True, api_key='') + CloudCoverDetection(root=dataset.root, split=dataset.split, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CloudCoverDetection(str(tmp_path)) def test_plot(self, dataset: CloudCoverDetection) -> None: - dataset.plot(dataset[0], suptitle='Test') - plt.close() - sample = 
dataset[0] + dataset.plot(sample, suptitle='Test') + plt.close() sample['prediction'] = sample['mask'].clone() dataset.plot(sample, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CloudCoverDetection) -> None: dataset = CloudCoverDetection( - root=dataset.root, - split='test', - bands=list(['B08']), - download=True, - api_key='', + root=dataset.root, split=dataset.split, bands=['B08'], download=True ) with pytest.raises( RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index ad0e26ed03d..34f67036d2a 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -18,44 +16,23 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestCV4AKenyaCropType: @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - source_md5 = '7f4dcb3f33743dddd73f453176308bfb' - labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9' - monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5) - monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5) - monkeypatch.setattr( - CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00'] - ) + def dataset( + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> CV4AKenyaCropType: + url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type') + monkeypatch.setattr(CV4AKenyaCropType, 'url', url) + monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1)))) monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606']) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2) root = str(tmp_path) transforms = nn.Identity() - return CV4AKenyaCropType( - root, - transforms=transforms, - download=True, - api_key='', - checksum=True, - verbose=True, - ) + return CV4AKenyaCropType(root, transforms=transforms, download=True) def test_getitem(self, dataset: CV4AKenyaCropType) -> None: x = dataset[0] @@ -66,60 +43,34 @@ def test_getitem(self, dataset: CV4AKenyaCropType) -> None: assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: CV4AKenyaCropType) -> None: - assert len(dataset) == 345 + assert len(dataset) == 1 def test_add(self, dataset: CV4AKenyaCropType) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 690 - - def test_get_splits(self, dataset: CV4AKenyaCropType) -> None: - train_field_ids, test_field_ids = dataset.get_splits() - assert isinstance(train_field_ids, list) - assert isinstance(test_field_ids, list) - assert len(train_field_ids) == 18 - assert len(test_field_ids) == 9 - assert 336 in train_field_ids - assert 336 not in test_field_ids - assert 
4793 in test_field_ids - assert 4793 not in train_field_ids + assert len(ds) == 2 def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: - CV4AKenyaCropType(root=dataset.root, download=True, api_key='') + CV4AKenyaCropType(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CV4AKenyaCropType(str(tmp_path)) - def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None: - with pytest.raises(AssertionError): - dataset._load_label_tile('foo') - - with pytest.raises(AssertionError): - dataset._load_all_image_tiles('foo', ('B01', 'B02')) - - with pytest.raises(AssertionError): - dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02')) - def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - CV4AKenyaCropType(bands=['B01', 'B02']) # type: ignore[arg-type] - - with pytest.raises(ValueError, match='is an invalid band name.'): CV4AKenyaCropType(bands=('foo', 'bar')) def test_plot(self, dataset: CV4AKenyaCropType) -> None: - dataset.plot(dataset[0], time_step=0, suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, time_step=0, suptitle='Test') + plt.close() sample['prediction'] = sample['mask'].clone() dataset.plot(sample, time_step=0, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None: dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01'])) - with pytest.raises( - RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' - ): - dataset.plot(dataset[0], time_step=0, suptitle='Single Band') + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index d165b064a90..bb18bed06ca 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -15,52 +13,33 @@ from torch.utils.data import ConcatDataset from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')): - shutil.copy(tarball, output_dir) - - -def fetch(collection_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestTropicalCyclone: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + request: SubRequest, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, ) -> TropicalCyclone: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - md5s = { - 'train': { - 'source': '2b818e0a0873728dabf52c7054a0ce4c', - 'labels': 'c3c2b6d02c469c5519f4add4f9132712', - }, - 'test': { - 'source': 'bc07c519ddf3ce88857435ddddf98a16', - 'labels': '3ca4243eff39b87c73e05ec8db1824bf', - }, - } - monkeypatch.setattr(TropicalCyclone, 'md5s', md5s) - monkeypatch.setattr(TropicalCyclone, 'size', 1) + url = os.path.join('tests', 'data', 'cyclone') + monkeypatch.setattr(TropicalCyclone, 'url', url) + monkeypatch.setattr(TropicalCyclone, 'size', 2) root = str(tmp_path) split = request.param transforms = nn.Identity() - return TropicalCyclone( - root, split, transforms, download=True, api_key='', checksum=True - ) + return TropicalCyclone(root, split, transforms, download=True) @pytest.mark.parametrize('index', [0, 1]) def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: x = dataset[index] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['storm_id'], str) - assert isinstance(x['relative_time'], int) - assert isinstance(x['ocean'], int) + assert isinstance(x['relative_time'], torch.Tensor) + assert isinstance(x['ocean'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) assert x['image'].shape == (3, dataset.size, dataset.size) @@ -73,7 +52,7 @@ def test_add(self, dataset: TropicalCyclone) -> None: assert len(ds) == 10 def test_already_downloaded(self, dataset: TropicalCyclone) -> None: - TropicalCyclone(root=dataset.root, download=True, api_key='') + TropicalCyclone(root=dataset.root, download=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -84,10 +63,9 @@ def test_not_downloaded(self, tmp_path: Path) -> None: TropicalCyclone(str(tmp_path)) def test_plot(self, dataset: TropicalCyclone) -> None: - dataset.plot(dataset[0], suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, suptitle='Test') + plt.close() sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py index 6f83b12a93d..ddf5b5df7fb 100644 --- a/tests/datasets/test_rwanda_field_boundary.py +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -19,45 +17,26 @@ RGBBandsMissingError, RwandaFieldBoundary, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestRwandaFieldBoundary: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + request: SubRequest, ) -> RwandaFieldBoundary: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5} - ) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - { - 'train_images': 'af9395e2e49deefebb35fa65fa378ba3', - 'test_images': 'd104bb82323a39e7c3b3b7dd0156f550', - 'train_labels': '6cceaf16a141cf73179253a783e7d51b', - }, - ) + url = os.path.join('tests', 'data', 'rwanda_field_boundary') + monkeypatch.setattr(RwandaFieldBoundary, 'url', url) + monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1}) root = str(tmp_path) split = request.param transforms = nn.Identity() - return RwandaFieldBoundary( - root, split, transforms=transforms, api_key='', download=True, checksum=True - ) + return RwandaFieldBoundary(root, split, transforms=transforms, download=True) def test_getitem(self, dataset: RwandaFieldBoundary) -> None: x = dataset[0] @@ -69,23 +48,12 @@ def test_getitem(self, dataset: RwandaFieldBoundary) -> None: assert 'mask' not in x def test_len(self, dataset: RwandaFieldBoundary) -> None: - assert len(dataset) == 5 + assert len(dataset) == 1 def test_add(self, dataset: RwandaFieldBoundary) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 10 - - def test_needs_extraction(self, tmp_path: Path) -> None: - root = str(tmp_path) - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn) - shutil.copy(url, root) - RwandaFieldBoundary(root, checksum=False) + assert len(ds) == 2 def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: RwandaFieldBoundary(root=dataset.root) @@ -94,35 +62,8 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): RwandaFieldBoundary(str(tmp_path)) - def test_corrupted(self, tmp_path: Path) -> None: - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - with open(os.path.join(tmp_path, fn), 'w') as f: - f.write('bad') - with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - RwandaFieldBoundary(root=str(tmp_path), checksum=True) - - def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - 
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - {'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'}, - ) - root = str(tmp_path) - with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'): - RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True) - - def test_no_api_key(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match='Must provide an API key to download'): - RwandaFieldBoundary(str(tmp_path), api_key=None, download=True) - def test_invalid_bands(self) -> None: - with pytest.raises(ValueError, match='is an invalid band name.'): + with pytest.raises(AssertionError): RwandaFieldBoundary(bands=('foo', 'bar')) def test_plot(self, dataset: RwandaFieldBoundary) -> None: diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index c53bfbed0fc..d6c9bc15c38 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -597,7 +597,7 @@ def test_lazy_import_missing(name: str) -> None: def test_azcopy(tmp_path: Path, azcopy: Executable) -> None: source = os.path.join('tests', 'data', 'cyclone') azcopy('sync', source, tmp_path, '--recursive=true') - assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels') + assert os.path.exists(tmp_path / 'test') def test_which() -> None: diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py index 2e271f89bc9..2cea90b396d 100644 --- a/tests/transforms/test_color.py +++ b/tests/transforms/test_color.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import kornia.augmentation as K import pytest import torch from torch import Tensor -from torchgeo.transforms import AugmentationSequential, RandomGrayscale +from torchgeo.transforms import RandomGrayscale @pytest.fixture @@ -33,12 +34,15 @@ def batch() -> dict[str, Tensor]: ], ) def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential( + RandomGrayscale(weights, p=1), keepdim=True, data_keys=None + ) + # https://github.com/kornia/kornia/issues/2848 + aug.keepdim = True output = aug(sample) assert output['image'].shape == sample['image'].shape - assert output['image'].sum() == sample['image'].sum() for i in range(1, 3): - assert torch.allclose(output['image'][0, 0], output['image'][0, i]) + assert torch.allclose(output['image'][0], output['image'][i]) @pytest.mark.parametrize( @@ -50,9 +54,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> ], ) def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None) output = aug(batch) assert output['image'].shape == batch['image'].shape - assert output['image'].sum() == batch['image'].sum() for i in range(1, 3): assert torch.allclose(output['image'][0, 0], output['image'][0, i]) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index 3d83f857304..9e6f54e48c4 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
+import kornia.augmentation as K import pytest import torch from torch import Tensor @@ -20,7 +21,6 @@ AppendRBNDVI, AppendSWI, AppendTriBandNormalizedDifferenceIndex, - AugmentationSequential, ) @@ -42,9 +42,8 @@ def batch() -> dict[str, Tensor]: def test_append_index_sample(sample: dict[str, Tensor]) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -52,9 +51,8 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None: def test_append_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -62,9 +60,9 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None: def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( + aug = K.AugmentationSequential( AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2), - data_keys=['image', 'mask'], + data_keys=None, ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -88,7 +86,7 @@ def test_append_normalized_difference_indices( sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices( sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 39021fc2acc..e9af302094b 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -43,18 +43,11 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" if stage in ['fit', 'validate']: - self.dataset = TropicalCyclone(split='train', **self.kwargs) - - storm_ids = [] - for item in self.dataset.collection: - storm_id = item['href'].split('/')[0].split('_')[-2] - storm_ids.append(storm_id) - + dataset = TropicalCyclone(split='train', **self.kwargs) train_indices, val_indices = group_shuffle_split( - storm_ids, test_size=0.2, random_state=0 + dataset.features['Storm ID'], test_size=0.2, random_state=0 ) - - self.train_dataset = Subset(self.dataset, train_indices) - self.val_dataset = Subset(self.dataset, val_indices) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) if stage in ['test']: self.test_dataset = TropicalCyclone(split='test', **self.kwargs) diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 8be13a170cf..fd325aaa6f8 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -51,20 +51,20 @@ class AgriFieldNet(RasterDataset): Dataset classes: - 0 - No-Data - 1 - Wheat - 2 - Mustard - 3 - Lentil - 4 - No Crop/Fallow - 5 - Green pea - 6 - Sugarcane - 8 - Garlic - 9 - Maize - 13 - Gram - 14 - Coriander - 15 - Potato - 16 - Berseem - 36 - Rice + * 0. No-Data + * 1. Wheat + * 2. Mustard + * 3. Lentil + * 4. No Crop/Fallow + * 5. Green pea + * 6. Sugarcane + * 8. Garlic + * 9. Maize + * 13. Gram + * 14. Coriander + * 15. Potato + * 16. Berseem + * 36. Rice If you use this dataset in your research, please cite the following dataset: diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 6682f42b7b7..686d5974324 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -5,7 +5,7 @@ import json import os -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache import matplotlib.pyplot as plt @@ -19,10 +19,9 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which -# TODO: read geospatial information from stac.json files class BeninSmallHolderCashews(NonGeoDataset): r"""Smallholder Cashew Plantations in Benin dataset. @@ -30,8 +29,8 @@ class BeninSmallHolderCashews(NonGeoDataset): in the center of Benin. Each pixel is classified for Well-managed plantation, Poorly-managed plantation, No plantation and other classes. The labels are generated using a combination of ground data collection with a handheld GPS device, - and final corrections based on Airbus Pléiades imagery. See - `this website `__ for dataset details. + and final corrections based on Airbus Pléiades imagery. See `this website + `__ for dataset details. Specifically, the data consists of Sentinel 2 imagery from a 120 km\ :sup:`2`\ area in the center of Benin over 71 points in time from 11/05/2019 to 10/30/2020 @@ -47,97 +46,88 @@ class BeninSmallHolderCashews(NonGeoDataset): If you use this dataset in your research, please cite the following: - * https://doi.org/10.34911/rdnt.hfv20i + * https://beta.source.coop/technoserve/cashews-benin/ .. note:: This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. 
""" - dataset_id = 'ts_cashew_benin' - collection_ids = ['ts_cashew_benin_source', 'ts_cashew_benin_labels'] - image_meta = { - 'filename': 'ts_cashew_benin_source.tar.gz', - 'md5': '957272c86e518a925a4e0d90dab4f92d', - } - target_meta = { - 'filename': 'ts_cashew_benin_labels.tar.gz', - 'md5': 'f9d3f0c671427d852fae9b52a0ae0051', - } + url = 'https://radiantearth.blob.core.windows.net/mlhub/technoserve-cashew-benin' dates = ( - '2019_11_05', - '2019_11_10', - '2019_11_15', - '2019_11_20', - '2019_11_30', - '2019_12_05', - '2019_12_10', - '2019_12_15', - '2019_12_20', - '2019_12_25', - '2019_12_30', - '2020_01_04', - '2020_01_09', - '2020_01_14', - '2020_01_19', - '2020_01_24', - '2020_01_29', - '2020_02_08', - '2020_02_13', - '2020_02_18', - '2020_02_23', - '2020_02_28', - '2020_03_04', - '2020_03_09', - '2020_03_14', - '2020_03_19', - '2020_03_24', - '2020_03_29', - '2020_04_03', - '2020_04_08', - '2020_04_13', - '2020_04_18', - '2020_04_23', - '2020_04_28', - '2020_05_03', - '2020_05_08', - '2020_05_13', - '2020_05_18', - '2020_05_23', - '2020_05_28', - '2020_06_02', - '2020_06_07', - '2020_06_12', - '2020_06_17', - '2020_06_22', - '2020_06_27', - '2020_07_02', - '2020_07_07', - '2020_07_12', - '2020_07_17', - '2020_07_22', - '2020_07_27', - '2020_08_01', - '2020_08_06', - '2020_08_11', - '2020_08_16', - '2020_08_21', - '2020_08_26', - '2020_08_31', - '2020_09_05', - '2020_09_10', - '2020_09_15', - '2020_09_20', - '2020_09_25', - '2020_09_30', - '2020_10_10', - '2020_10_15', - '2020_10_20', - '2020_10_25', - '2020_10_30', + '20191105', + '20191110', + '20191115', + '20191120', + '20191130', + '20191205', + '20191210', + '20191215', + '20191220', + '20191225', + '20191230', + '20200104', + '20200109', + '20200114', + '20200119', + '20200124', + '20200129', + '20200208', + '20200213', + '20200218', + '20200223', + '20200228', + '20200304', + '20200309', + '20200314', + '20200319', + '20200324', + '20200329', + '20200403', + '20200408', + '20200413', + '20200418', + '20200423', + '20200428', + '20200503', + '20200508', + '20200513', + '20200518', + '20200523', + '20200528', + '20200602', + '20200607', + '20200612', + '20200617', + '20200622', + '20200627', + '20200702', + '20200707', + '20200712', + '20200717', + '20200722', + '20200727', + '20200801', + '20200806', + '20200811', + '20200816', + '20200821', + '20200826', + '20200831', + '20200905', + '20200910', + '20200915', + '20200920', + '20200925', + '20200930', + '20201010', + '20201015', + '20201020', + '20201025', + '20201030', ) all_bands = ( @@ -176,12 +166,9 @@ def __init__( root: str = 'data', chip_size: int = 256, stride: int = 128, - bands: tuple[str, ...] = all_bands, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new Benin Smallholder Cashew Plantations Dataset instance. @@ -194,28 +181,21 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: + AssertionError: If *bands* is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. 
""" - self._validate_bands(bands) + assert set(bands) <= set(self.all_bands) self.root = root self.chip_size = chip_size self.stride = stride self.bands = bands self.transforms = transforms - self.checksum = checksum - self.verbose = verbose + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() # Calculate the indices that we will use over all tiles self.chips_metadata = [] @@ -238,7 +218,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ y, x = self.chips_metadata[index] - img, transform, crs = self._load_all_imagery(self.bands) + img, transform, crs = self._load_all_imagery() labels = self._load_mask(transform) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] @@ -266,92 +246,55 @@ def __len__(self) -> int: """ return len(self.chips_metadata) - def _validate_bands(self, bands: tuple[str, ...]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - AssertionError: if ``bands`` is not a tuple - ValueError: if an invalid band name is provided - """ - assert isinstance(bands, tuple), 'The list of bands must be a tuple' - for band in bands: - if band not in self.all_bands: - raise ValueError(f"'{band}' is an invalid band name.") - @lru_cache(maxsize=128) - def _load_all_imagery( - self, bands: tuple[str, ...] = all_bands - ) -> tuple[Tensor, rasterio.Affine, CRS]: + def _load_all_imagery(self) -> tuple[Tensor, rasterio.Affine, CRS]: """Load all the imagery (across time) for the dataset. - Optionally allows for subsetting of the bands that are loaded. - - Args: - bands: tuple of bands to load - Returns: imagery of shape (70, number of bands, 1186, 1122) where 70 is the number of points in time, 1186 is the tile height, and 1122 is the tile width rasterio affine transform, mapping pixel coordinates to geo coordinates coordinate reference system of transform """ - if self.verbose: - print('Loading all imagery') - img = torch.zeros( len(self.dates), - len(bands), + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32, ) for date_index, date in enumerate(self.dates): - single_scene, transform, crs = self._load_single_scene(date, self.bands) + single_scene, transform, crs = self._load_single_scene(date) img[date_index] = single_scene return img, transform, crs @lru_cache(maxsize=128) - def _load_single_scene( - self, date: str, bands: tuple[str, ...] - ) -> tuple[Tensor, rasterio.Affine, CRS]: + def _load_single_scene(self, date: str) -> tuple[Tensor, rasterio.Affine, CRS]: """Load the imagery for a single date. - Optionally allows for subsetting of the bands that are loaded. - Args: date: date of the imagery to load - bands: bands to load Returns: Tensor containing a single image tile, rasterio affine transform, mapping pixel coordinates to geo coordinates, and coordinate reference system of transform. 
- - Raises: - AssertionError: if ``date`` is invalid """ - assert date in self.dates - - if self.verbose: - print(f'Loading imagery at {date}') - img = torch.zeros( - len(bands), self.tile_height, self.tile_width, dtype=torch.float32 + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32 ) for band_index, band_name in enumerate(self.bands): filepath = os.path.join( self.root, - 'ts_cashew_benin_source', - f'ts_cashew_benin_source_00_{date}', - f'{band_name}.tif', + 'imagery', + '00', + f'00_{date}', + f'00_{date}_{band_name}_10m.tif', ) with rasterio.open(filepath) as src: - transform = src.transform # same transform for every bands + transform = src.transform # same transform for every band crs = src.crs array = src.read().astype(np.float32) img[band_index] = torch.from_numpy(array) @@ -362,10 +305,7 @@ def _load_single_scene( def _load_mask(self, transform: rasterio.Affine) -> Tensor: """Rasterizes the dataset's labels (in geojson format).""" # Create a mask layer out of the geojson - mask_geojson_fn = os.path.join( - self.root, 'ts_cashew_benin_labels', '_common', 'labels.geojson' - ) - with open(mask_geojson_fn) as f: + with open(os.path.join(self.root, 'labels', '00.geojson')) as f: geojson = json.load(f) labels = [ @@ -385,44 +325,24 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor: mask = torch.from_numpy(mask_data).long() return mask - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta['filename']), - self.target_meta['md5'] if self.checksum else None, - ) - - return images and targets - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. 
- - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(os.path.join(self.root, 'labels', '00.geojson')): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() - image_archive_path = os.path.join(self.root, self.image_meta['filename']) - target_archive_path = os.path.join(self.root, self.target_meta['filename']) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -454,9 +374,6 @@ def plot( else: raise RGBBandsMissingError() - num_time_points = sample['image'].shape[0] - assert time_step < num_time_points - image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3) image = np.clip(image / 3000, 0, 1) mask = sample['mask'].numpy() diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index bb975c8002b..ab440648a17 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -196,7 +196,7 @@ def _load_target(self, filename: str) -> Tensor: target mask """ with rasterio.open(os.path.join(self.root, 'train_agbm', filename), 'r') as src: - arr: np.typing.NDArray[np.float_] = src.read() + arr: np.typing.NDArray[np.float64] = src.read() target = torch.from_numpy(arr).float() return target diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 2aeea01c568..552693684ba 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -3,13 +3,12 @@ """Cloud Cover Detection Challenge dataset.""" -import json import os from collections.abc import Callable, Sequence -from typing import Any import matplotlib.pyplot as plt import numpy as np +import pandas as pd import rasterio import torch from matplotlib.figure import Figure @@ -17,18 +16,16 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which -# TODO: read geospatial information from stac.json files class CloudCoverDetection(NonGeoDataset): - """Cloud Cover Detection Challenge dataset. + """Sentinel-2 Cloud Cover Segmentation Dataset. - This training dataset was generated as part of a - `crowdsourcing competition + This training dataset was generated as part of a `crowdsourcing competition `_ on DrivenData.org, and - later on was validated using a team of expert annotators. See - `this website `__ + later on was validated using a team of expert annotators. See `this website + `__ for dataset details. 
The dataset consists of Sentinel-2 satellite imagery and corresponding cloudy @@ -51,96 +48,52 @@ class CloudCoverDetection(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.4 """ - collection_ids = [ - 'ref_cloud_cover_detection_challenge_v1_train_source', - 'ref_cloud_cover_detection_challenge_v1_train_labels', - 'ref_cloud_cover_detection_challenge_v1_test_source', - 'ref_cloud_cover_detection_challenge_v1_test_labels', - ] - - image_meta = { - 'train': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_train_source.tar.gz', - 'md5': '32cfe38e313bcedc09dca3f0f9575eea', - }, - 'test': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', - 'md5': '6c67edae18716598d47298f24992db6c', - }, - } - - target_meta = { - 'train': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_train_labels.tar.gz', - 'md5': '695dfb1034924c10fbb17f9293815671', - }, - 'test': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', - 'md5': 'ec2b42bb43e9a03a01ae096f9e09db9c', - }, - } - - collection_names = { - 'train': [ - 'ref_cloud_cover_detection_challenge_v1_train_source', - 'ref_cloud_cover_detection_challenge_v1_train_labels', - ], - 'test': [ - 'ref_cloud_cover_detection_challenge_v1_test_source', - 'ref_cloud_cover_detection_challenge_v1_test_labels', - ], - } - - band_names = ['B02', 'B03', 'B04', 'B08'] - + url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final' + all_bands = ['B02', 'B03', 'B04', 'B08'] rgb_bands = ['B04', 'B03', 'B02'] + splits = {'train': 'public', 'test': 'private'} def __init__( self, root: str = 'data', split: str = 'train', - bands: Sequence[str] = band_names, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: - """Initiatlize a new Cloud Cover Detection Dataset instance. + """Initialize a CloudCoverDetection instance. Args: root: root directory where dataset can be found - split: train/val/test split to load + split: 'train' or 'test' bands: the subset of bands to load transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If *split* or *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ + assert split in self.splits + assert set(bands) <= set(self.all_bands) + self.root = root self.split = split - self.transforms = transforms - self.checksum = checksum - - self._validate_bands(bands) self.bands = bands + self.transforms = transforms + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self.csv = os.path.join(self.root, self.split, f'{self.split}_metadata.csv') + self._verify() - self.chip_paths = self._load_collections() + self.metadata = pd.read_csv(self.csv) def __len__(self) -> int: """Return the number of items in the dataset.
@@ -148,7 +101,7 @@ def __len__(self) -> int: Returns: length of dataset in integer """ - return len(self.chip_paths) + return len(self.metadata) def __getitem__(self, index: int) -> dict[str, Tensor]: """Returns a sample from dataset. @@ -159,192 +112,65 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at given index """ - image = self._load_image(index) - label = self._load_target(index) - sample: dict[str, Tensor] = {'image': image, 'mask': label} + chip_id = self.metadata.iat[index, 0] + image = self._load_image(chip_id) + label = self._load_target(chip_id) + sample = {'image': image, 'mask': label} if self.transforms is not None: sample = self.transforms(sample) return sample - def _load_image(self, index: int) -> Tensor: + def _load_image(self, chip_id: str) -> Tensor: """Load all source images for a chip. Args: - index: position of the indexed chip + chip_id: ID of the chip. Returns: a tensor of stacked source image data """ - source_asset_paths = self.chip_paths[index]['source'] + path = os.path.join(self.root, self.split, f'{self.split}_features', chip_id) images = [] - for path in source_asset_paths: - with rasterio.open(path) as image_data: - image_array = image_data.read(1).astype(np.int32) - images.append(image_array) - image_stack: np.typing.NDArray[np.int_] = np.stack(images, axis=0) - image_tensor = torch.from_numpy(image_stack) - return image_tensor - - def _load_target(self, index: int) -> Tensor: + for band in self.bands: + with rasterio.open(os.path.join(path, f'{band}.tif')) as src: + images.append(src.read(1).astype(np.float32)) + return torch.from_numpy(np.stack(images, axis=0)) + + def _load_target(self, chip_id: str) -> Tensor: """Load label image for a chip. Args: - index: position of the indexed chip + chip_id: ID of the chip. Returns: a tensor of the label image data """ - label_asset_path = self.chip_paths[index]['target'][0] - with rasterio.open(label_asset_path) as target_data: - target_img = target_data.read(1).astype(np.int32) - - target_array: np.typing.NDArray[np.int_] = np.array(target_img) - target_tensor = torch.from_numpy(target_array) - return target_tensor - - @staticmethod - def _read_json_data(object_path: str) -> Any: - """Loads a JSON file. - - Args: - object_path: string path to the JSON file - - Returns: - json_data: JSON object / dictionary - - """ - with open(object_path) as read_contents: - json_data = json.load(read_contents) - return json_data - - def _load_items(self, item_json: str) -> dict[str, list[str]]: - """Loads the label item and corresponding source items. 
- - Args: - item_json: a string path to the item JSON file on disk - - Returns: - a dictionary with paths to the source and target TIF filenames - """ - item_meta = {} - - label_data = self._read_json_data(item_json) - label_asset_path = os.path.join( - os.path.split(item_json)[0], label_data['assets']['labels']['href'] - ) - item_meta['target'] = [label_asset_path] - - source_item_hrefs = [] - for link in label_data['links']: - if link['rel'] == 'source': - source_item_hrefs.append( - os.path.join(self.root, link['href'].replace('../../', '')) - ) - - source_item_hrefs = sorted(source_item_hrefs) - source_item_paths = [] - - for item_href in source_item_hrefs: - source_item_path = os.path.split(item_href)[0] - source_data = self._read_json_data(item_href) - source_item_assets = [] - for asset_key, asset_value in source_data['assets'].items(): - if asset_key in self.bands: - source_item_assets.append( - os.path.join(source_item_path, asset_value['href']) - ) - source_item_assets = sorted(source_item_assets) - for source_item_asset in source_item_assets: - source_item_paths.append(source_item_asset) - - item_meta['source'] = source_item_paths - return item_meta - - def _load_collections(self) -> list[dict[str, Any]]: - """Loads the paths to source and label assets for each collection. - - Returns: - a dictionary with lists of filepaths to all assets for each chip/item - - Raises: - RuntimeError if collection.json is not found in the uncompressed dataset - """ - indexed_chips = [] - label_collection: list[str] = [] - for c in self.collection_names[self.split]: - if 'label' in c: - label_collection.append(c) - label_collection_path = os.path.join(self.root, label_collection[0]) - label_collection_json = os.path.join(label_collection_path, 'collection.json') - - label_collection_item_hrefs = [] - for link in self._read_json_data(label_collection_json)['links']: - if link['rel'] == 'item': - label_collection_item_hrefs.append(link['href']) - - label_collection_item_hrefs = sorted(label_collection_item_hrefs) - - for label_href in label_collection_item_hrefs: - label_json = os.path.join(label_collection_path, label_href) - indexed_item = self._load_items(label_json) - indexed_chips.append(indexed_item) - - return indexed_chips - - def _validate_bands(self, bands: Sequence[str]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - ValueError: if an invalid band name is provided - """ - for band in bands: - if band not in self.band_names: - raise ValueError(f"'{band}' is an invalid band name.") - - def _check_integrity(self) -> bool: - """Check integrity of dataset. 
- - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta[self.split]['filename']), - self.image_meta[self.split]['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta[self.split]['filename']), - self.target_meta[self.split]['md5'] if self.checksum else None, - ) - - return images and targets + path = os.path.join(self.root, self.split, f'{self.split}_labels') + with rasterio.open(os.path.join(path, f'{chip_id}.tif')) as src: + return torch.from_numpy(src.read(1).astype(np.int64)) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(self.csv): + return - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') - return + # Download the dataset + self._download() - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) - - image_archive_path = os.path.join( - self.root, self.image_meta[self.split]['filename'] - ) - target_archive_path = os.path.join( - self.root, self.target_meta[self.split]['filename'] - ) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + directory = os.path.join(self.root, self.split) + os.makedirs(directory, exist_ok=True) + url = f'{self.url}/{self.splits[self.split]}' + azcopy = which('azcopy') + azcopy('sync', url, directory, '--recursive=true') def plot( self, diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index a532c1539c4..feeb6ff0ec2 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -3,9 +3,8 @@ """CV4A Kenya Crop Type dataset.""" -import csv import os -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache import matplotlib.pyplot as plt @@ -17,16 +16,23 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which -# TODO: read geospatial information from stac.json files class CV4AKenyaCropType(NonGeoDataset): - """CV4A Kenya Crop Type dataset. + """CV4A Kenya Crop Type Competition dataset. - Used in a competition in the Computer NonGeo for Agriculture (CV4A) workshop in - ICLR 2020. See `this website `__ - for dataset details. + The `CV4A Kenya Crop Type Competition + `__ + dataset was produced as part of the Crop Type Detection competition at the + Computer Vision for Agriculture (CV4A) Workshop at the ICLR 2020 conference. + The objective of the competition was to create a machine learning model to + classify fields by crop type from images collected during the growing season + by the Sentinel-2 satellites. + + See the `dataset documentation + `__ + for details. Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time. 
@@ -54,29 +60,12 @@ class CV4AKenyaCropType(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. """ - collection_ids = [ - 'ref_african_crops_kenya_02_labels', - 'ref_african_crops_kenya_02_source', - ] - image_meta = { - 'filename': 'ref_african_crops_kenya_02_source.tar.gz', - 'md5': '9c2004782f6dc83abb1bf45ba4d0da46', - } - target_meta = { - 'filename': 'ref_african_crops_kenya_02_labels.tar.gz', - 'md5': '93949abd0ae82ba564f5a933cefd8215', - } - - tile_names = [ - 'ref_african_crops_kenya_02_tile_00', - 'ref_african_crops_kenya_02_tile_01', - 'ref_african_crops_kenya_02_tile_02', - 'ref_african_crops_kenya_02_tile_03', - ] + url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge' + tiles = list(map(str, range(4))) dates = [ '20190606', '20190701', @@ -92,7 +81,7 @@ class CV4AKenyaCropType(NonGeoDataset): '20191004', '20191103', ] - band_names = ( + all_bands = ( 'B01', 'B02', 'B03', @@ -107,7 +96,6 @@ class CV4AKenyaCropType(NonGeoDataset): 'B12', 'CLD', ) - rgb_bands = ['B04', 'B03', 'B02'] # Same for all tiles @@ -119,12 +107,9 @@ def __init__( self, root: str = 'data', chip_size: int = 256, stride: int = 128, - bands: tuple[str, ...] = band_names, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new CV4A Kenya Crop Type Dataset instance. @@ -137,32 +122,25 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: + AssertionError: If *bands* is invalid. DatasetNotFoundError: If dataset is not found and *download* is False.
""" - self._validate_bands(bands) + assert set(bands) <= set(self.all_bands) self.root = root self.chip_size = chip_size self.stride = stride self.bands = bands self.transforms = transforms - self.checksum = checksum - self.verbose = verbose + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() # Calculate the indices that we will use over all tiles self.chips_metadata = [] - for tile_index in range(len(self.tile_names)): + for tile_index in range(len(self.tiles)): for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ self.tile_height - self.chip_size ]: @@ -181,10 +159,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data, labels, field ids, and metadata at that index """ tile_index, y, x = self.chips_metadata[index] - tile_name = self.tile_names[tile_index] + tile = self.tiles[tile_index] - img = self._load_all_image_tiles(tile_name, self.bands) - labels, field_ids = self._load_label_tile(tile_name) + img = self._load_all_image_tiles(tile) + labels, field_ids = self._load_label_tile(tile) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] labels = labels[y : y + self.chip_size, x : x + self.chip_size] @@ -213,193 +191,94 @@ def __len__(self) -> int: return len(self.chips_metadata) @lru_cache(maxsize=128) - def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]: + def _load_label_tile(self, tile: str) -> tuple[Tensor, Tensor]: """Load a single _tile_ of labels and field_ids. Args: - tile_name: name of tile to load + tile: name of tile to load Returns: tuple of labels and field ids - - Raises: - AssertionError: if ``tile_name`` is invalid """ - assert tile_name in self.tile_names - - if self.verbose: - print(f'Loading labels/field_ids for {tile_name}') - - directory = os.path.join( - self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label' - ) + directory = os.path.join(self.root, 'data', tile) - with Image.open(os.path.join(directory, 'labels.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_label.tif')) as img: array: np.typing.NDArray[np.int_] = np.array(img) labels = torch.from_numpy(array) - with Image.open(os.path.join(directory, 'field_ids.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_field_id.tif')) as img: array = np.array(img) field_ids = torch.from_numpy(array) - return (labels, field_ids) - - def _validate_bands(self, bands: tuple[str, ...]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - AssertionError: if ``bands`` is not a tuple - ValueError: if an invalid band name is provided - """ - assert isinstance(bands, tuple), 'The list of bands must be a tuple' - for band in bands: - if band not in self.band_names: - raise ValueError(f"'{band}' is an invalid band name.") + return labels, field_ids @lru_cache(maxsize=128) - def _load_all_image_tiles( - self, tile_name: str, bands: tuple[str, ...] = band_names - ) -> Tensor: + def _load_all_image_tiles(self, tile: str) -> Tensor: """Load all the imagery (across time) for a single _tile_. Optionally allows for subsetting of the bands that are loaded. 
Args: - tile_name: name of tile to load - bands: tuple of bands to load + tile: name of tile to load Returns: imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of - points in time, 3035 is the tile height, and 2016 is the tile width - - Raises: - AssertionError: if ``tile_name`` is invalid + points in time, 3035 is the tile height, and 2016 is the tile width """ - assert tile_name in self.tile_names - - if self.verbose: - print(f'Loading all imagery for {tile_name}') - img = torch.zeros( len(self.dates), - len(bands), + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32, ) for date_index, date in enumerate(self.dates): - img[date_index] = self._load_single_image_tile(tile_name, date, self.bands) + img[date_index] = self._load_single_image_tile(tile, date) return img @lru_cache(maxsize=128) - def _load_single_image_tile( - self, tile_name: str, date: str, bands: tuple[str, ...] - ) -> Tensor: + def _load_single_image_tile(self, tile: str, date: str) -> Tensor: """Load the imagery for a single tile for a single date. - Optionally allows for subsetting of the bands that are loaded. - Args: - tile_name: name of tile to load + tile: name of tile to load date: date of tile to load - bands: bands to load Returns: array containing a single image tile - - Raises: - AssertionError: if ``tile_name`` or ``date`` is invalid """ - assert tile_name in self.tile_names - assert date in self.dates - - if self.verbose: - print(f'Loading imagery for {tile_name} at {date}') - + directory = os.path.join(self.root, 'data', tile, date) img = torch.zeros( - len(bands), self.tile_height, self.tile_width, dtype=torch.float32 + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32 ) for band_index, band_name in enumerate(self.bands): - filepath = os.path.join( - self.root, - 'ref_african_crops_kenya_02_source', - f'{tile_name}_{date}', - f'{band_name}.tif', - ) + filepath = os.path.join(directory, f'{tile}_{band_name}_{date}.tif') with Image.open(filepath) as band_img: array: np.typing.NDArray[np.int_] = np.array(band_img) img[band_index] = torch.from_numpy(array) return img - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta['filename']), - self.target_meta['md5'] if self.checksum else None, - ) - - return images and targets - - def get_splits(self) -> tuple[list[int], list[int]]: - """Get the field_ids for the train/test splits from the dataset directory. - - Returns: - list of training field_ids and list of testing field_ids - """ - train_field_ids = [] - test_field_ids = [] - splits_fn = os.path.join( - self.root, - 'ref_african_crops_kenya_02_labels', - '_common', - 'field_train_test_ids.csv', - ) - - with open(splits_fn, newline='') as f: - reader = csv.reader(f) - - # Skip header row - next(reader) - - for row in reader: - train_field_ids.append(int(row[0])) - if row[1]: - test_field_ids.append(int(row[1])) - - return train_field_ids, test_field_ids - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. 
- - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(os.path.join(self.root, 'FieldIds.csv')): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() - image_archive_path = os.path.join(self.root, self.image_meta['filename']) - target_archive_path = os.path.join(self.root, self.target_meta['filename']) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -439,13 +318,7 @@ def plot( image, mask = sample['image'], sample['mask'] - assert time_step <= image.shape[0] - 1, ( - 'The specified time step' - f' does not exist, image only contains {image.shape[0]} time' - ' instances.' - ) - - image = image[time_step, rgb_indices, :, :] + image = image[time_step, rgb_indices] fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index eccca9d7314..747463b69de 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -3,7 +3,6 @@ """Tropical Cyclone Wind Estimation Competition dataset.""" -import json import os from collections.abc import Callable from functools import lru_cache @@ -11,6 +10,7 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch from matplotlib.figure import Figure from PIL import Image @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which class TropicalCyclone(NonGeoDataset): @@ -26,10 +26,9 @@ class TropicalCyclone(NonGeoDataset): A collection of tropical storms in the Atlantic and East Pacific Oceans from 2000 to 2019 with corresponding maximum sustained surface wind speed. This dataset is split - into training and test categories for the purpose of a competition. - - See https://www.drivendata.org/competitions/72/predict-wind-speeds/ for more - information about the competition. + into training and test categories for the purpose of a competition. Read more about + the competition here: + https://www.drivendata.org/competitions/72/predict-wind-speeds/. If you use this dataset in your research, please cite the following paper: @@ -39,31 +38,17 @@ class TropicalCyclone(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionchanged:: 0.4 Class name changed from TropicalCycloneWindEstimation to TropicalCyclone to be consistent with TropicalCycloneDataModule. 
""" - collection_id = 'nasa_tropical_storm_competition' - collection_ids = [ - 'nasa_tropical_storm_competition_train_source', - 'nasa_tropical_storm_competition_test_source', - 'nasa_tropical_storm_competition_train_labels', - 'nasa_tropical_storm_competition_test_labels', - ] - md5s = { - 'train': { - 'source': '97e913667a398704ea8d28196d91dad6', - 'labels': '97d02608b74c82ffe7496a9404a30413', - }, - 'test': { - 'source': '8d88099e4b310feb7781d776a6e1dcef', - 'labels': 'd910c430f90153c1f78a99cbc08e7bd0', - }, - } + url = ( + 'https://radiantearth.blob.core.windows.net/mlhub/nasa-tropical-storm-challenge' + ) size = 366 def __init__( @@ -72,10 +57,8 @@ def __init__( split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: - """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset. + """Initialize a new TropicalCyclone instance. Args: root: root directory where dataset can be found @@ -83,30 +66,26 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in self.md5s + assert split in {'train', 'test'} self.root = root self.split = split self.transforms = transforms - self.checksum = checksum + self.download = download - if download: - self._download(api_key) + self.filename = f'{split}_set' + if split == 'train': + self.filename = f'{split}ing_set' - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() - output_dir = '_'.join([self.collection_id, split, 'source']) - filename = os.path.join(root, output_dir, 'collection.json') - with open(filename) as f: - self.collection = json.load(f)['links'] + self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv')) + self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv')) def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. @@ -117,15 +96,14 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - source_id = os.path.split(self.collection[index]['href'])[0] - directory = os.path.join( - self.root, - '_'.join([self.collection_id, self.split, '{0}']), - source_id.replace('source', '{0}'), - ) + sample = { + 'relative_time': torch.tensor(self.features.iat[index, 2]), + 'ocean': torch.tensor(self.features.iat[index, 3]), + 'label': torch.tensor(self.labels.iat[index, 1]), + } - sample: dict[str, Any] = {'image': self._load_image(directory)} - sample.update(self._load_features(directory)) + image_id = self.labels.iat[index, 0] + sample['image'] = self._load_image(image_id) if self.transforms is not None: sample = self.transforms(sample) @@ -138,19 +116,19 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.collection) + return len(self.labels) @lru_cache - def _load_image(self, directory: str) -> Tensor: + def _load_image(self, image_id: str) -> Tensor: """Load a single image. Args: - directory: directory containing image + image_id: Filename of the image. 
Returns: the image """ - filename = os.path.join(directory.format('source'), 'image.jpg') + filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: # Moved in PIL 9.1.0 @@ -164,61 +142,30 @@ def _load_image(self, directory: str) -> Tensor: tensor = tensor.permute((2, 0, 1)).float() return tensor - def _load_features(self, directory: str) -> dict[str, Any]: - """Load features for a single image. - - Args: - directory: directory containing image - - Returns: - the features - """ - filename = os.path.join(directory.format('source'), 'features.json') - with open(filename) as f: - features: dict[str, Any] = json.load(f) - - filename = os.path.join(directory.format('labels'), 'labels.json') - with open(filename) as f: - features.update(json.load(f)) - - features['relative_time'] = int(features['relative_time']) - features['ocean'] = int(features['ocean']) - features['label'] = torch.tensor(int(features['wind_speed'])).float() - - return features - - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - for split, resources in self.md5s.items(): - for resource_type, md5 in resources.items(): - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename + '.tar.gz') - if not check_integrity(filename, md5 if self.checksum else None): - return False - return True - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + exists = [os.path.exists(os.path.join(self.root, file)) for file in files] + if all(exists): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - for split, resources in self.md5s.items(): - for resource_type in resources: - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename) + '.tar.gz' - extract_archive(filename, self.root) + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + directory = os.path.join(self.root, self.split) + os.makedirs(directory, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', f'{self.url}/{self.split}', directory, '--recursive=true') + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + for file in files: + azcopy('copy', f'{self.url}/{file}', self.root) def plot( self, diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 697ddeb0fb5..b9cd1556f9f 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -235,7 +235,7 @@ def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: the image """ with rasterio.open(path) as f: - array: np.typing.NDArray[np.float_] = f.read( + array: np.typing.NDArray[np.float64] = f.read( out_shape=shape, out_dtype='float32', resampling=Resampling.bilinear ) tensor = torch.from_numpy(array) diff 
--git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index daa1987e3c8..0082dd152b9 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -243,10 +243,12 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap(arr: 'np.typing.NDArray[Any]') -> 'np.typing.NDArray[np.float_]': + def apply_cmap( + arr: 'np.typing.NDArray[Any]', + ) -> 'np.typing.NDArray[np.float64]': # Color 0 as black, while applying default color map for the class indices. cmap = plt.get_cmap('viridis') - im: np.typing.NDArray[np.float_] = cmap(arr / len(self.class_map)) + im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map)) im[arr == 0] = 0 return im diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 9e6bc5a8909..982917fcdd6 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -41,7 +41,7 @@ class EuroSAT(NonGeoClassificationDataset): * Permanent Crop * Residential Buildings * River - * SeaLake + * Sea & Lake This dataset uses the train/val/test splits defined in the "In-domain representation learning for remote sensing" paper: diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index f34afa443df..4b64a1b488e 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -78,7 +78,7 @@ class PatternNet(NonGeoClassificationDataset): * https://doi.org/10.1016/j.isprsjprs.2018.01.004 """ - url = 'https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K' + url = 'https://hf.co/datasets/torchgeo/PatternNet/resolve/2dbd901b00e301967a5c5146b25454f5d3455ad0/PatternNet.zip' md5 = '96d54b3224c5350a98d55d5a7e6984ad' filename = 'PatternNet.zip' directory = os.path.join('PatternNet', 'images') diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index cd5adff76c8..fb066424b1a 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -91,6 +91,13 @@ class RESISC45(NonGeoClassificationDataset): If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/jproc.2017.2675998 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `rarfile `_ to extract the dataset, + which is stored in a RAR file """ url = 'https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv' diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 9439e525ab3..07a496ea974 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -3,6 +3,7 @@ """Rwanda Field Boundary Competition dataset.""" +import glob import os from collections.abc import Callable, Sequence @@ -16,11 +17,11 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which class RwandaFieldBoundary(NonGeoDataset): - r"""Rwanda Field Boundary Competition dataset. + """Rwanda Field Boundary Competition dataset. This dataset contains field boundaries for smallholder farms in eastern Rwanda. 
The Nasa Harvest program funded a team of annotators from TaQadam to label Planet @@ -46,40 +47,20 @@ class RwandaFieldBoundary(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.5 """ - dataset_id = 'nasa_rwanda_field_boundary_competition' - collection_ids = [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_labels_train', - 'nasa_rwanda_field_boundary_competition_source_test', - ] - number_of_patches_per_split = {'train': 57, 'test': 13} - - filenames = { - 'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - } - md5s = { - 'train_images': '1f9ec08038218e67e11f82a86849b333', - 'test_images': '17bb0e56eedde2e7a43c57aa908dc125', - 'train_labels': '10e4eb761523c57b6d3bdf9394004f5f', - } + url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition' + splits = {'train': 57, 'test': 13} dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') - all_bands = ('B01', 'B02', 'B03', 'B04') rgb_bands = ('B03', 'B02', 'B01') - classes = ['No field-boundary', 'Field-boundary'] - splits = ['train', 'test'] - def __init__( self, root: str = 'data', @@ -87,8 +68,6 @@ def __init__( bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: """Initialize a new RwandaFieldBoundary instance. @@ -99,49 +78,29 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If *split* or *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) assert split in self.splits - if download and api_key is None: - raise RuntimeError('Must provide an API key to download the dataset') + assert set(bands) <= set(self.all_bands) + self.root = root + self.split = split self.bands = bands self.transforms = transforms - self.split = split self.download = download - self.api_key = api_key - self.checksum = checksum + self._verify() - self.image_filenames: list[list[list[str]]] = [] - self.mask_filenames: list[str] = [] - for i in range(self.number_of_patches_per_split[split]): - dates = [] - for date in self.dates: - patch = [] - for band in self.bands: - fn = os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - f'{band}.tif', - ) - patch.append(fn) - dates.append(patch) - self.image_filenames.append(dates) - self.mask_filenames.append( - os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_labels_{split}', - f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}', - 'raster_labels.tif', - ) - ) + def __len__(self) -> int: + """Return the number of chips in the dataset. 
diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py
index 2aa23f3ba75..ea36b974da1 100644
--- a/torchgeo/datasets/seco.py
+++ b/torchgeo/datasets/seco.py
@@ -169,7 +169,7 @@ def _load_patch(self, root: str, subdir: str) -> Tensor:
                 # what could be sped up throughout later. There is also a potential
                 # slowdown here from converting to/from a PIL Image just to resize.
                 # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f
-                pil_image = Image.fromarray(band_data)  # type: ignore[no-untyped-call]
+                pil_image = Image.fromarray(band_data)
                 # Moved in PIL 9.1.0
                 try:
                     resample = Image.Resampling.BILINEAR
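The dropped `# type: ignore[no-untyped-call]` reflects that recent Pillow releases ship type annotations for `Image.fromarray`, so mypy no longer flags the call (that is the assumed motivation here). The neighboring `Image.Resampling` lookup in the context lines is a version-compatibility pattern worth spelling out; a self-contained sketch:

```python
import numpy as np
from PIL import Image

# Image.Resampling was introduced in Pillow 9.1.0; older releases exposed the
# resampling constants directly on the Image module.
try:
    resample = Image.Resampling.BILINEAR
except AttributeError:  # Pillow < 9.1.0
    resample = Image.BILINEAR

band_data = np.zeros((64, 64), dtype=np.uint8)
pil_image = Image.fromarray(band_data)
resized = pil_image.resize((32, 32), resample=resample)
```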
diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py
index 24a7b4e8b90..d13d84dcf15 100644
--- a/torchgeo/trainers/detection.py
+++ b/torchgeo/trainers/detection.py
@@ -82,7 +82,7 @@ def __init__(
             weights: Initial model weights. True for ImageNet weights, False or None
                 for random weights.
             in_channels: Number of input channels to model.
-            num_classes: Number of prediction classes.
+            num_classes: Number of prediction classes (including the background).
             trainable_layers: Number of trainable layers.
             lr: Learning rate for optimizer.
             patience: Patience for learning rate scheduler.
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index dad26635c64..afd71521002 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -54,7 +54,7 @@ def __init__(
                 model does not support pretrained weights. Pretrained ViT weight enums
                 are not supported yet.
             in_channels: Number of input channels to model.
-            num_classes: Number of prediction classes.
+            num_classes: Number of prediction classes (including the background).
             num_filters: Number of filters. Only applicable when model='fcn'.
             loss: Name of the loss function, currently supports
                 'ce', 'jaccard' or 'focal' loss.
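Both docstring fixes make the same point: these trainers count the background as one of the prediction classes, so `num_classes` must include it. A short sketch of the practical consequence for a binary task such as field-boundary segmentation (the parameter values here are illustrative, not prescribed):

```python
from torchgeo.trainers import SemanticSegmentationTask

# One foreground class plus the background class, so num_classes=2, not 1.
task = SemanticSegmentationTask(
    model='unet',
    backbone='resnet50',
    in_channels=3,
    num_classes=2,
    loss='ce',
)
```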