diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml deleted file mode 100644 index 179e13377ad..00000000000 --- a/.github/workflows/coverage.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: coverage -on: - workflow_run: - workflows: [tests] - types: [completed] -jobs: - codecov: - name: codecov - runs-on: ubuntu-latest - steps: - - name: Clone repo - uses: actions/checkout@v4.1.1 - - name: Download coverage artifacts - uses: actions/download-artifact@v4.1.1 - with: - github-token: ${{ github.token }} - run-id: ${{ github.event.workflow_run.id }} - - name: Report coverage - uses: codecov/codecov-action@v4.0.1 - with: - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 6c6c680ae63..dce17ec3f94 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -16,7 +16,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -44,7 +44,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -72,7 +72,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index 7d6236c9408..8bd60000d5f 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -18,7 +18,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -43,7 +43,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -68,7 +68,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -93,7 +93,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -118,7 +118,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache @@ -143,7 +143,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ec5a68feb25..d72a114f092 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,7 +17,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - name: Clone repo uses: actions/checkout@v4.1.1 @@ -31,7 +31,7 @@ jobs: with: path: ${{ env.pythonLocation }} key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/datasets.txt') }}-${{ hashFiles('requirements/tests.txt') }} - if: ${{ ! (runner.os == 'macOS' && matrix.python-version == '3.11') }} + if: ${{ ! (runner.os == 'macOS' && (matrix.python-version == '3.11' || matrix.python-version == '3.12')) }} - name: Setup headless display for pyvista uses: pyvista/setup-headless-display-action@v2 - name: Install apt dependencies (Linux) @@ -56,11 +56,10 @@ jobs: run: | pytest --cov=torchgeo --cov-report=xml --durations=10 python3 -m torchgeo --help - - name: Upload coverage artifact - uses: actions/upload-artifact@v4.3.0 + - name: Report coverage + uses: codecov/codecov-action@v3.1.6 with: - name: coverage_${{ matrix.os }}_py-${{ matrix.python-version }} - path: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} minimum: name: minimum runs-on: ubuntu-latest @@ -96,25 +95,10 @@ jobs: run: | pytest --cov=torchgeo --cov-report=xml --durations=10 python3 -m torchgeo --help - - name: Upload coverage artifact - uses: actions/upload-artifact@v4.3.0 - with: - name: coverage_minimum - path: coverage.xml - codecov: - name: codecov - runs-on: ubuntu-latest - needs: [latest, minimum] - steps: - - name: Clone repo - uses: actions/checkout@v4.1.1 - - name: Download coverage artifacts - uses: actions/download-artifact@v4.1.1 - name: Report coverage uses: codecov/codecov-action@v3.1.6 with: token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }} cancel-in-progress: true diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml index 5833bccd4f6..800a3ca46a7 100644 --- a/.github/workflows/tutorials.yaml +++ b/.github/workflows/tutorials.yaml @@ -20,17 +20,17 @@ jobs: - name: Set up python uses: actions/setup-python@v5.0.0 with: - python-version: '3.11' + python-version: '3.12' - name: Cache dependencies uses: actions/cache@v4.0.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-tutorials + key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/docs.txt') }}-${{ hashFiles('requirements/tests.txt') }}-tutorials - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - pip install .[docs,tests] planetary_computer pystac + pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac pip cache purge - name: List pip dependencies run: pip list diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 534c7e24017..9cfd9895a23 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.11" + python: "3.12" # Configuration of the Python environment to be used python: diff --git a/README.md b/README.md index 36ae3aa37eb..ec31fcb7753 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights weights = ResNet18_Weights.SENTINEL2_ALL_MOCO model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10) -model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False) +model.load_state_dict(weights.get_state_dict(progress=True), strict=False) ``` These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html). diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 645acbacc47..3f9539db034 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -23,6 +23,11 @@ Aboveground Woody Biomass .. autoclass:: AbovegroundLiveWoodyBiomassDensity +AgriFieldNet +^^^^^^^^^^^^ + +.. autoclass:: AgriFieldNet + Airphen ^^^^^^^ diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index ed7655e843d..ac611c00e45 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -1,5 +1,6 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Aboveground Woody Biomass`_,Masks,"Landsat, LiDAR","CC-BY-4.0","40,000x40,000",30 +`AgriFieldNet`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10 `Airphen`_,Imagery,Airphen,-,"1,280x960",0.047--0.09 `Aster Global DEM`_,Masks,Aster,"public domain","3,601x3,601",30 `Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,- diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 26d97fcbc6c..e15dd1ebc16 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -228,7 +228,7 @@ "source": [ "in_chans = weights.meta[\"in_chans\"]\n", "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", - "model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" + "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" ] }, { diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 80a53f52643..ca76d9dc5a9 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -501,7 +501,7 @@ "id": "w4ZbjxPyHoiB" }, "source": [ - "It's even possible to chain indices along with augmentations from kornia for a single callable during training." + "It's even possible to chain indices along with augmentations from Kornia for a single callable during training." ] }, { diff --git a/experiments/torchgeo/run_resisc45_experiments.py b/experiments/torchgeo/run_resisc45_experiments.py index e4915bc0c64..1049edf749b 100755 --- a/experiments/torchgeo/run_resisc45_experiments.py +++ b/experiments/torchgeo/run_resisc45_experiments.py @@ -37,7 +37,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}" + experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" output_dir = os.path.join("output", "resisc45_experiments") log_dir = os.path.join(output_dir, "logs") diff --git a/experiments/torchgeo/run_so2sat_experiments.py b/experiments/torchgeo/run_so2sat_experiments.py index d574d0987c0..fd8562ca345 100755 --- a/experiments/torchgeo/run_so2sat_experiments.py +++ b/experiments/torchgeo/run_so2sat_experiments.py @@ -37,7 +37,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}" + experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" output_dir = os.path.join("output", "so2sat_experiments") log_dir = os.path.join(output_dir, "logs") diff --git a/experiments/torchgeo/run_so2sat_seed_experiments.py b/experiments/torchgeo/run_so2sat_seed_experiments.py index 90e6d274910..2f88eba3c75 100755 --- a/experiments/torchgeo/run_so2sat_seed_experiments.py +++ b/experiments/torchgeo/run_so2sat_seed_experiments.py @@ -38,7 +38,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: for model, lr, loss, weights, seed in itertools.product( model_options, lr_options, loss_options, weight_options, seeds ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}_{seed}" + experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}" output_dir = os.path.join("output", "so2sat_seed_experiments") log_dir = os.path.join(output_dir, "logs") diff --git a/pyproject.toml b/pyproject.toml index d4a8c7ae955..5b0a12c4830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: GIS", ] diff --git a/requirements/datasets.txt b/requirements/datasets.txt index a52aa575d9e..d2caa09ad30 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -3,7 +3,7 @@ h5py==3.10.0 laspy==2.5.3 opencv-python==4.9.0.80 pycocotools==2.0.7 -pyvista==0.43.2 +pyvista==0.43.3 radiant-mlhub==0.4.1 rarfile==4.1 scikit-image==0.22.0 diff --git a/requirements/docs.txt b/requirements/docs.txt index 898ddd75ed4..9cfb14b7a3d 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ # docs -ipywidgets==8.1.1 +ipywidgets==8.1.2 nbsphinx==0.9.3 sphinx==5.3.0 diff --git a/requirements/required.txt b/requirements/required.txt index 7631a543e30..e7cb2af6ac1 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -6,9 +6,9 @@ einops==0.7.0 fiona==1.9.5 kornia==0.7.1 lightly==1.4.25 -lightning[pytorch-extra]==2.1.4 +lightning[pytorch-extra]==2.2.0 matplotlib==3.8.2 -numpy==1.26.3 +numpy==1.26.4 pandas==2.2.0 pillow==10.2.0 pyproj==3.6.1 diff --git a/tests/data/agrifieldnet/data.py b/tests/data/agrifieldnet/data.py new file mode 100644 index 00000000000..e0b4d0e256c --- /dev/null +++ b/tests/data/agrifieldnet/data.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + + +def generate_test_data(paths: str) -> str: + """Create test data archive for AgriFieldNet dataset. + + Args: + paths: path to store test data + n_samples: number of samples. + + Returns: + md5 hash of created archive + """ + dtype = np.uint8 + dtype_max = np.iinfo(dtype).max + + SIZE = 32 + + np.random.seed(0) + + bands = ( + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12", + ) + + profile = { + "dtype": dtype, + "width": SIZE, + "height": SIZE, + "count": 1, + "crs": CRS.from_epsg(32644), + "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0), + } + + source_dir = os.path.join(paths, "source") + train_mask_dir = os.path.join(paths, "train_labels") + test_field_dir = os.path.join(paths, "test_labels") + + os.makedirs(source_dir, exist_ok=True) + os.makedirs(train_mask_dir, exist_ok=True) + os.makedirs(test_field_dir, exist_ok=True) + + source_unique_folder_ids = ["32407", "8641e", "a419f", "eac11", "ff450"] + train_folder_ids = source_unique_folder_ids[0:5] + test_folder_ids = source_unique_folder_ids[3:5] + + for id in source_unique_folder_ids: + directory = os.path.join( + source_dir, "ref_agrifieldnet_competition_v1_source_" + id + ) + os.makedirs(directory, exist_ok=True) + + for band in bands: + train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype) + path = os.path.join( + directory, f"ref_agrifieldnet_competition_v1_source_{id}_{band}_10m.tif" + ) + with rasterio.open(path, "w", **profile) as src: + src.write(train_arr, 1) + + for id in train_folder_ids: + train_mask_arr = np.random.randint(size=(SIZE, SIZE), low=0, high=6) + path = os.path.join( + train_mask_dir, f"ref_agrifieldnet_competition_v1_labels_train_{id}.tif" + ) + with rasterio.open(path, "w", **profile) as src: + src.write(train_mask_arr, 1) + + train_field_arr = np.random.randint(20, size=(SIZE, SIZE), dtype=np.uint16) + path = os.path.join( + train_mask_dir, + f"ref_agrifieldnet_competition_v1_labels_train_{id}_field_ids.tif", + ) + with rasterio.open(path, "w", **profile) as src: + src.write(train_field_arr, 1) + + for id in test_folder_ids: + test_field_arr = np.random.randint(10, 30, size=(SIZE, SIZE), dtype=np.uint16) + path = os.path.join( + test_field_dir, + f"ref_agrifieldnet_competition_v1_labels_test_{id}_field_ids.tif", + ) + with rasterio.open(path, "w", **profile) as src: + src.write(test_field_arr, 1) + + +if __name__ == "__main__": + generate_test_data(os.getcwd()) diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B01_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B01_10m.tif new file mode 100644 index 00000000000..0742b602456 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B01_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B02_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B02_10m.tif new file mode 100644 index 00000000000..bf0ae9b6617 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B02_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B03_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B03_10m.tif new file mode 100644 index 00000000000..e64a2ac631e Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B03_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B04_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B04_10m.tif new file mode 100644 index 00000000000..a3d8acef4e0 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B04_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B05_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B05_10m.tif new file mode 100644 index 00000000000..355c7bc7876 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B05_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B06_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B06_10m.tif new file mode 100644 index 00000000000..9a9b2daf1c6 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B06_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B07_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B07_10m.tif new file mode 100644 index 00000000000..1af11ac91f5 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B07_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B08_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B08_10m.tif new file mode 100644 index 00000000000..034c797567b Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B08_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B09_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B09_10m.tif new file mode 100644 index 00000000000..a89dee1c684 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B09_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B11_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B11_10m.tif new file mode 100644 index 00000000000..f7cdc4b45e8 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B11_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B12_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B12_10m.tif new file mode 100644 index 00000000000..6dbc5efeed4 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B12_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B8A_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B8A_10m.tif new file mode 100644 index 00000000000..e496e2fbdf5 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_32407/ref_agrifieldnet_competition_v1_source_32407_B8A_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B01_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B01_10m.tif new file mode 100644 index 00000000000..6e0286c1b22 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B01_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B02_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B02_10m.tif new file mode 100644 index 00000000000..225081bf2be Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B02_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B03_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B03_10m.tif new file mode 100644 index 00000000000..e8a7ec10ea1 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B03_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B04_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B04_10m.tif new file mode 100644 index 00000000000..ac3212f3794 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B04_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B05_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B05_10m.tif new file mode 100644 index 00000000000..27d90bf00df Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B05_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B06_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B06_10m.tif new file mode 100644 index 00000000000..ef11a438781 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B06_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B07_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B07_10m.tif new file mode 100644 index 00000000000..7f7d1671d9c Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B07_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B08_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B08_10m.tif new file mode 100644 index 00000000000..e7d3cee49bd Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B08_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B09_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B09_10m.tif new file mode 100644 index 00000000000..3bcb8c42657 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B09_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B11_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B11_10m.tif new file mode 100644 index 00000000000..6b2dab469ca Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B11_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B12_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B12_10m.tif new file mode 100644 index 00000000000..e62de633c32 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B12_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B8A_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B8A_10m.tif new file mode 100644 index 00000000000..6a89eb0a188 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_8641e/ref_agrifieldnet_competition_v1_source_8641e_B8A_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B01_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B01_10m.tif new file mode 100644 index 00000000000..16d350924f9 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B01_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B02_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B02_10m.tif new file mode 100644 index 00000000000..05eebfe80d7 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B02_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B03_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B03_10m.tif new file mode 100644 index 00000000000..029b6f759c2 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B03_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B04_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B04_10m.tif new file mode 100644 index 00000000000..78a46f389e4 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B04_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B05_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B05_10m.tif new file mode 100644 index 00000000000..8247cc70628 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B05_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B06_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B06_10m.tif new file mode 100644 index 00000000000..eec2b8d95e7 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B06_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B07_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B07_10m.tif new file mode 100644 index 00000000000..b92f99a97e6 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B07_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B08_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B08_10m.tif new file mode 100644 index 00000000000..3731fdf6438 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B08_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B09_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B09_10m.tif new file mode 100644 index 00000000000..a105d1539c6 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B09_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B11_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B11_10m.tif new file mode 100644 index 00000000000..c95848df494 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B11_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B12_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B12_10m.tif new file mode 100644 index 00000000000..1480fdac909 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B12_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B8A_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B8A_10m.tif new file mode 100644 index 00000000000..dcbcc06c974 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_a419f/ref_agrifieldnet_competition_v1_source_a419f_B8A_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B01_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B01_10m.tif new file mode 100644 index 00000000000..9854d423670 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B01_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B02_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B02_10m.tif new file mode 100644 index 00000000000..7c9649abfe5 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B02_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B03_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B03_10m.tif new file mode 100644 index 00000000000..69526d1da36 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B03_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B04_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B04_10m.tif new file mode 100644 index 00000000000..facb4b67c05 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B04_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B05_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B05_10m.tif new file mode 100644 index 00000000000..abc820cd23b Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B05_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B06_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B06_10m.tif new file mode 100644 index 00000000000..2136c9f3968 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B06_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B07_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B07_10m.tif new file mode 100644 index 00000000000..97e5bab8672 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B07_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B08_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B08_10m.tif new file mode 100644 index 00000000000..e75985b7160 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B08_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B09_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B09_10m.tif new file mode 100644 index 00000000000..1d9b2c0eb37 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B09_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B11_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B11_10m.tif new file mode 100644 index 00000000000..532abf9c548 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B11_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B12_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B12_10m.tif new file mode 100644 index 00000000000..b4df1ecf49f Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B12_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B8A_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B8A_10m.tif new file mode 100644 index 00000000000..e1c736a62aa Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_eac11/ref_agrifieldnet_competition_v1_source_eac11_B8A_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B01_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B01_10m.tif new file mode 100644 index 00000000000..21c86e54836 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B01_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B02_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B02_10m.tif new file mode 100644 index 00000000000..cfa585af189 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B02_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B03_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B03_10m.tif new file mode 100644 index 00000000000..b9e4ad19d04 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B03_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B04_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B04_10m.tif new file mode 100644 index 00000000000..7f72a6c0c2e Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B04_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B05_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B05_10m.tif new file mode 100644 index 00000000000..3d57b1dba97 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B05_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B06_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B06_10m.tif new file mode 100644 index 00000000000..b0e74ae62fe Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B06_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B07_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B07_10m.tif new file mode 100644 index 00000000000..70b7d009aa5 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B07_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B08_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B08_10m.tif new file mode 100644 index 00000000000..50a09d7f650 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B08_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B09_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B09_10m.tif new file mode 100644 index 00000000000..5d7f9f16397 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B09_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B11_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B11_10m.tif new file mode 100644 index 00000000000..7fb135a4204 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B11_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B12_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B12_10m.tif new file mode 100644 index 00000000000..df850197a0e Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B12_10m.tif differ diff --git a/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B8A_10m.tif b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B8A_10m.tif new file mode 100644 index 00000000000..c0a55795956 Binary files /dev/null and b/tests/data/agrifieldnet/source/ref_agrifieldnet_competition_v1_source_ff450/ref_agrifieldnet_competition_v1_source_ff450_B8A_10m.tif differ diff --git a/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_eac11_field_ids.tif b/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_eac11_field_ids.tif new file mode 100644 index 00000000000..e6ff77ac591 Binary files /dev/null and b/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_eac11_field_ids.tif differ diff --git a/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_ff450_field_ids.tif b/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_ff450_field_ids.tif new file mode 100644 index 00000000000..989159c514d Binary files /dev/null and b/tests/data/agrifieldnet/test_labels/ref_agrifieldnet_competition_v1_labels_test_ff450_field_ids.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407.tif new file mode 100644 index 00000000000..02ec404531d Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407_field_ids.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407_field_ids.tif new file mode 100644 index 00000000000..438abb97986 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_32407_field_ids.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e.tif new file mode 100644 index 00000000000..6eaa8bc18b5 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e_field_ids.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e_field_ids.tif new file mode 100644 index 00000000000..2433430a853 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_8641e_field_ids.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f.tif new file mode 100644 index 00000000000..4e0e2631d07 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f_field_ids.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f_field_ids.tif new file mode 100644 index 00000000000..01b9858f309 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_a419f_field_ids.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11.tif new file mode 100644 index 00000000000..e40ca4f2d94 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11_field_ids.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11_field_ids.tif new file mode 100644 index 00000000000..a33060fef4a Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_eac11_field_ids.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450.tif new file mode 100644 index 00000000000..a55c23fa29a Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450.tif differ diff --git a/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450_field_ids.tif b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450_field_ids.tif new file mode 100644 index 00000000000..bb8019f6357 Binary files /dev/null and b/tests/data/agrifieldnet/train_labels/ref_agrifieldnet_competition_v1_labels_train_ff450_field_ids.tif differ diff --git a/tests/data/nccm/13090442.zip b/tests/data/nccm/13090442.zip deleted file mode 100644 index 19d0792078a..00000000000 Binary files a/tests/data/nccm/13090442.zip and /dev/null differ diff --git a/tests/data/nccm/13090442/CDL2017_clip.tif b/tests/data/nccm/13090442/CDL2017_clip.tif deleted file mode 100644 index 8dce2bb82e9..00000000000 Binary files a/tests/data/nccm/13090442/CDL2017_clip.tif and /dev/null differ diff --git a/tests/data/nccm/13090442/CDL2018_clip1.tif b/tests/data/nccm/13090442/CDL2018_clip1.tif deleted file mode 100644 index 531cd5f4f1f..00000000000 Binary files a/tests/data/nccm/13090442/CDL2018_clip1.tif and /dev/null differ diff --git a/tests/data/nccm/13090442/CDL2019_clip.tif b/tests/data/nccm/13090442/CDL2019_clip.tif deleted file mode 100644 index 67be3087ed7..00000000000 Binary files a/tests/data/nccm/13090442/CDL2019_clip.tif and /dev/null differ diff --git a/tests/data/nccm/CDL2017_clip.tif b/tests/data/nccm/CDL2017_clip.tif new file mode 100644 index 00000000000..1040f7936c6 Binary files /dev/null and b/tests/data/nccm/CDL2017_clip.tif differ diff --git a/tests/data/nccm/CDL2018_clip1.tif b/tests/data/nccm/CDL2018_clip1.tif new file mode 100644 index 00000000000..3313fef10d1 Binary files /dev/null and b/tests/data/nccm/CDL2018_clip1.tif differ diff --git a/tests/data/nccm/CDL2019_clip.tif b/tests/data/nccm/CDL2019_clip.tif new file mode 100644 index 00000000000..9c4d1dcae44 Binary files /dev/null and b/tests/data/nccm/CDL2019_clip.tif differ diff --git a/tests/data/nccm/data.py b/tests/data/nccm/data.py index 6a98ca3a2d0..2956f147033 100644 --- a/tests/data/nccm/data.py +++ b/tests/data/nccm/data.py @@ -5,7 +5,6 @@ import hashlib import os -import shutil import numpy as np import rasterio @@ -48,20 +47,14 @@ def create_file(path: str, dtype: str): if __name__ == "__main__": - dir = os.path.join(os.getcwd(), "13090442") - - if os.path.exists(dir) and os.path.isdir(dir): - shutil.rmtree(dir) - + dir = os.path.join(os.getcwd()) os.makedirs(dir, exist_ok=True) for file in files: create_file(os.path.join(dir, file), dtype="int8") - # Compress data - shutil.make_archive("13090442", "zip", ".", dir) - # Compute checksums - with open("13090442.zip", "rb") as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(f"13090442.zip: {md5}") + for file in files: + with open(file, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{file}: {md5}") diff --git a/tests/datasets/test_agrifieldnet.py b/tests/datasets/test_agrifieldnet.py new file mode 100644 index 00000000000..cdb539671f9 --- /dev/null +++ b/tests/datasets/test_agrifieldnet.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from rasterio.crs import CRS + +from torchgeo.datasets import ( + AgriFieldNet, + BoundingBox, + DatasetNotFoundError, + IntersectionDataset, + RGBBandsMissingError, + UnionDataset, +) + + +class TestAgriFieldNet: + @pytest.fixture + def dataset(self) -> AgriFieldNet: + path = os.path.join("tests", "data", "agrifieldnet") + transforms = nn.Identity() + return AgriFieldNet(paths=path, transforms=transforms) + + def test_getitem(self, dataset: AgriFieldNet) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + + def test_and(self, dataset: AgriFieldNet) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: AgriFieldNet) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_already_downloaded(self, dataset: AgriFieldNet) -> None: + AgriFieldNet(paths=dataset.paths) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + AgriFieldNet(str(tmp_path)) + + def test_plot(self, dataset: AgriFieldNet) -> None: + x = dataset[dataset.bounds] + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: AgriFieldNet) -> None: + x = dataset[dataset.bounds] + x["prediction"] = x["mask"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() + + def test_invalid_query(self, dataset: AgriFieldNet) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] + + def test_rgb_bands_absent_plot(self, dataset: AgriFieldNet) -> None: + with pytest.raises( + RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + ): + ds = AgriFieldNet(dataset.paths, bands=["B01", "B02", "B05"]) + x = ds[ds.bounds] + ds.plot(x, suptitle="Test") + plt.close() diff --git a/tests/datasets/test_nccm.py b/tests/datasets/test_nccm.py index 6637da3e840..0d922d9d3d5 100644 --- a/tests/datasets/test_nccm.py +++ b/tests/datasets/test_nccm.py @@ -25,9 +25,19 @@ class TestNCCM: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url) - url = os.path.join("tests", "data", "nccm", "13090442.zip") + md5s = { + 2017: "ae5c390d0ffb8970d544b8a09142759f", + 2018: "0d453bdb8ea5b7318c33e62513760580", + 2019: "d4ab7ab00bb57623eafb6b27747e5639", + } + monkeypatch.setattr(NCCM, "md5s", md5s) + urls = { + 2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"), + 2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"), + 2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"), + } + monkeypatch.setattr(NCCM, "urls", urls) transforms = nn.Identity() - monkeypatch.setattr(NCCM, "url", url) root = str(tmp_path) return NCCM(root, transforms=transforms, download=True, checksum=True) @@ -48,11 +58,8 @@ def test_or(self, dataset: NCCM) -> None: def test_already_extracted(self, dataset: NCCM) -> None: NCCM(dataset.paths, download=True) - def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "nccm", "13090442.zip") - root = str(tmp_path) - shutil.copy(pathname, root) - NCCM(root) + def test_already_downloaded(self, dataset: NCCM) -> None: + NCCM(dataset.paths, download=True) def test_plot(self, dataset: NCCM) -> None: query = dataset.bounds diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 52d7a9be25d..06da0a359eb 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None: def test_load_state_dict(checkpoint: str, model: Module) -> None: _, state_dict = extract_backbone(checkpoint) - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None: @@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) f" model {expected_in_channels}. Overriding with new input channels" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None: @@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None f" {expected_num_classes}. Overriding with new num classes" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_reinit_initial_conv_layer() -> None: diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 912ccb58db5..604c0d7e3ec 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -6,7 +6,6 @@ from typing import Any, Optional import kornia.augmentation as K -import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange @@ -113,7 +112,7 @@ def __init__( self.test_splits = test_splits self.class_set = class_set self.use_prior_labels = use_prior_labels - self.prior_smoothing_constant = torch.tensor(prior_smoothing_constant) + self.prior_smoothing_constant = prior_smoothing_constant if self.use_prior_labels: self.layers = [ diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index ed4fbe15795..3a3b7531cba 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -6,7 +6,6 @@ from typing import Any import kornia.augmentation as K -import torch from torch import Tensor from ..datasets import SpaceNet1 @@ -88,6 +87,6 @@ def on_after_batch_transfer( # We add 1 to the mask to map the current {background, building} labels to # the values {1, 2}. This is necessary because we add 0 padding to the # mask that we want to ignore in the loss function. - batch["mask"] += torch.tensor(1) + batch["mask"] += 1 return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 235ab83c8eb..71eb3f168dc 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -5,6 +5,7 @@ from .advance import ADVANCE from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity +from .agrifieldnet import AgriFieldNet from .airphen import Airphen from .astergdem import AsterGDEM from .benin_cashews import BeninSmallHolderCashews @@ -139,6 +140,7 @@ __all__ = ( # GeoDataset "AbovegroundLiveWoodyBiomassDensity", + "AgriFieldNet", "Airphen", "AsterGDEM", "CanadianBuildingFootprints", diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py new file mode 100644 index 00000000000..67a0c994f44 --- /dev/null +++ b/torchgeo/datasets/agrifieldnet.py @@ -0,0 +1,272 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""AgriFieldNet India Challenge dataset.""" + +import os +import re +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, Union, cast + +import matplotlib.pyplot as plt +import torch +from matplotlib.figure import Figure +from rasterio.crs import CRS +from torch import Tensor + +from .geo import RasterDataset +from .utils import BoundingBox, RGBBandsMissingError + + +class AgriFieldNet(RasterDataset): + """AgriFieldNet India Challenge dataset. + + The `AgriFieldNet India Challenge + `__ dataset + includes satellite imagery from Sentinel-2 cloud free composites + (single snapshot) and labels for crop type that were collected by ground survey. + The Sentinel-2 data are then matched with corresponding labels. + The dataset contains 7081 fields, which have been split into training and + test sets (5551 fields in the train and 1530 fields in the test). + Satellite imagery and labels are tiled into 256x256 chips adding up to 1217 tiles. + The fields are distributed across all chips, some chips may only have train or + test fields and some may have both. Since the labels are derived from data + collected on the ground, not all the pixels are labeled in each chip. + If the field ID for a pixel is set to 0 it means that pixel is not included in + either of the train or test set (and correspondingly the crop label + will be 0 as well). For this challenge train and test sets have slightly + different crop type distributions. The train set follows the distribution + of ground reference data which is a skewed distribution with a few dominant + crops being over represented. The test set was drawn randomly from an area + weighted field list that ensured that fields with less common crop types + were better represented in the test set. The original dataset can be + downloaded from `Source Cooperative `__. + + Dataset format: + + * images are 12-band Sentinel-2 data + * masks are tiff images with unique values representing the class and field id + + Dataset classes: + + 0 - No-Data + 1 - Wheat + 2 - Mustard + 3 - Lentil + 4 - No Crop/Fallow + 5 - Green pea + 6 - Sugarcane + 8 - Garlic + 9 - Maize + 13 - Gram + 14 - Coriander + 15 - Potato + 16 - Berseem + 36 - Rice + + If you use this dataset in your research, please cite the following dataset: + + * https://doi.org/10.34911/rdnt.wu92p1 + + .. versionadded:: 0.6 + """ + + filename_regex = r""" + ^ref_agrifieldnet_competition_v1_source_ + (?P[a-z0-9]{5}) + _(?PB[0-9A-Z]{2})_10m + """ + + rgb_bands = ["B04", "B03", "B02"] + all_bands = [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12", + ] + + cmap = { + 0: (0, 0, 0, 255), + 1: (255, 211, 0, 255), + 2: (255, 37, 37, 255), + 3: (0, 168, 226, 255), + 4: (255, 158, 9, 255), + 5: (37, 111, 0, 255), + 6: (255, 255, 0, 255), + 8: (111, 166, 0, 255), + 9: (0, 175, 73, 255), + 13: (222, 166, 9, 255), + 14: (222, 166, 9, 255), + 15: (124, 211, 255, 255), + 16: (226, 0, 124, 255), + 36: (137, 96, 83, 255), + } + + def __init__( + self, + paths: Union[str, Iterable[str]] = "data", + crs: Optional[CRS] = None, + classes: list[int] = list(cmap.keys()), + bands: Sequence[str] = all_bands, + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + cache: bool = True, + ) -> None: + """Initialize a new AgriFieldNet dataset instance. + + Args: + paths: one or more root directories to search for files to load + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + classes: list of classes to include, the rest will be mapped to 0 + (defaults to all classes) + bands: the subset of bands to load + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + cache: if True, cache the dataset in memory + + Raises: + DatasetNotFoundError: If dataset is not found. + """ + assert ( + set(classes) <= self.cmap.keys() + ), f"Only the following classes are valid: {list(self.cmap.keys())}." + assert 0 in classes, "Classes must include the background class: 0" + + self.paths = paths + self.classes = classes + self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) + self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8) + + super().__init__( + paths=paths, crs=crs, bands=bands, transforms=transforms, cache=cache + ) + + # Map chosen classes to ordinal numbers, all others mapped to background class + for v, k in enumerate(self.classes): + self.ordinal_map[k] = v + self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data, label, and field ids at that index + """ + assert isinstance(self.paths, str) + + hits = self.index.intersection(tuple(query), objects=True) + filepaths = cast(list[str], [hit.object for hit in hits]) + + if not filepaths: + raise IndexError( + f"query: {query} not found in index with bounds: {self.bounds}" + ) + + data_list: list[Tensor] = [] + filename_regex = re.compile(self.filename_regex, re.VERBOSE) + for band in self.bands: + band_filepaths = [] + for filepath in filepaths: + filename = os.path.basename(filepath) + directory = os.path.dirname(filepath) + match = re.match(filename_regex, filename) + if match: + if "band" in match.groupdict(): + start = match.start("band") + end = match.end("band") + filename = filename[:start] + band + filename[end:] + filepath = os.path.join(directory, filename) + band_filepaths.append(filepath) + data_list.append(self._merge_files(band_filepaths, query)) + image = torch.cat(data_list) + + mask_filepaths = [] + for root, dirs, files in os.walk(os.path.join(self.paths, "train_labels")): + for file in files: + if not file.endswith("_field_ids.tif") and file.endswith(".tif"): + file_path = os.path.join(root, file) + mask_filepaths.append(file_path) + + mask = self._merge_files(mask_filepaths, query) + mask = self.ordinal_map[mask.squeeze().long()] + + sample = { + "crs": self.crs, + "bbox": query, + "image": image.float(), + "mask": mask.long(), + } + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + RGBBandsMissingError: If *bands* does not include all RGB bands. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise RGBBandsMissingError() + + image = sample["image"][rgb_indices].permute(1, 2, 0) + image = (image - image.min()) / (image.max() - image.min()) + + mask = sample["mask"].squeeze() + ncols = 2 + + showing_prediction = "prediction" in sample + if showing_prediction: + pred = sample["prediction"].squeeze() + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) + axs[0].imshow(image) + axs[0].axis("off") + axs[1].imshow(self.ordinal_cmap[mask], interpolation="none") + axs[1].axis("off") + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + + if showing_prediction: + axs[2].imshow(self.ordinal_cmap[pred], interpolation="none") + axs[2].axis("off") + if show_titles: + axs[2].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 9edda8d26cb..6b9f95d34bb 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -387,8 +387,7 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor: dtype=np.uint8, ) - mask = torch.from_numpy(mask_data) - mask = mask.long() + mask = torch.from_numpy(mask_data).long() return mask def _check_integrity(self) -> bool: diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 3a821811f54..9a127248a8b 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -408,8 +408,7 @@ def _load_image(self, index: int) -> Tensor: ) images.append(array) arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0) - tensor = torch.from_numpy(arrays) - tensor = tensor.float() + tensor = torch.from_numpy(arrays).float() return tensor def _load_target(self, index: int) -> Tensor: diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index eff773308ae..970c5594950 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -198,8 +198,7 @@ def _load_target(self, filename: str) -> Tensor: with rasterio.open(os.path.join(self.root, "train_agbm", filename), "r") as src: arr: "np.typing.NDArray[np.float_]" = src.read() - target = torch.from_numpy(arr) - target = target.float() + target = torch.from_numpy(arr).float() return target def _verify(self) -> None: diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 5f19a935beb..60bbe78e21d 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -384,7 +384,7 @@ def plot( else: n_cols = 2 - image, mask = sample["image"] / torch.tensor(3000), sample["mask"] + image, mask = sample["image"] / 3000, sample["mask"] fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 47cd3766af6..0e0518502ad 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -144,8 +144,7 @@ def _load_image(self, index: int) -> Tensor: filename = os.path.join(self.root, self.images[index]) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -160,8 +159,7 @@ def _load_target(self, index: int) -> Tensor: the target """ target = int(self.targets[index]) - tensor = torch.tensor(target) - tensor = tensor.float() + tensor = torch.tensor(target).float() return tensor def _check_integrity(self) -> bool: diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 49cf247e409..c9a1243e970 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -164,8 +164,7 @@ def _load_image(self, directory: str) -> Tensor: img = img.resize(size=(self.size, self.size), resample=resample) array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) tensor = torch.from_numpy(array) - tensor = tensor.permute((2, 0, 1)) - tensor = tensor.float() + tensor = tensor.permute((2, 0, 1)).float() return tensor def _load_features(self, directory: str) -> dict[str, Any]: diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index e93a35fa9af..7dfa50fb2ab 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -204,8 +204,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 9253cadbafd..715c3bfedab 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -494,9 +494,7 @@ def plot( assert len(hsi_indices) == 3 def normalize(x: Tensor) -> Tensor: - # https://github.com/pytorch/pytorch/issues/116327 - out: Tensor = (x - x.min()) / (x.max() - x.min()) - return out + return (x - x.min()) / (x.max() - x.min()) ncols = 3 diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index b3b8c79d57a..b3ab0a6fd9c 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -135,8 +135,7 @@ def _load_image(self, path: str) -> Tensor: """ with rio.open(path) as img: array = img.read().astype(np.int32) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: @@ -151,8 +150,7 @@ def _load_target(self, path: str) -> Tensor: with rio.open(path) as img: array = img.read().astype(np.int32) array = np.clip(array, 0, 1) - mask = torch.from_numpy(array[0]) - mask = mask.long() + mask = torch.from_numpy(array[0]).long() return mask def __len__(self) -> int: diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 1962bdcbca5..8dec50b562e 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -364,8 +364,7 @@ def _load_image(self, id_: str) -> Tensor: filename = os.path.join(self.root, "output", id_ + ".jpg") with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -383,8 +382,7 @@ def _load_target(self, id_: str) -> Tensor: filename = os.path.join(self.root, "output", id_ + "_m.png") with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) - tensor = torch.from_numpy(array) - tensor = tensor.long() + tensor = torch.from_numpy(array).long() return tensor def _verify_data(self) -> bool: diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 0642557d187..76481fed1dd 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -109,8 +109,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index fc5b15af20a..9c7e2aaff4e 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -210,8 +210,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index f26c2564842..5eaa426d230 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -220,8 +220,7 @@ def _load_raster(self, filename: int, source: str) -> Tensor: array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0) if array.dtype == np.uint16: array = array.astype(np.int32) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _verify(self, url: str, md5: Optional[str] = None) -> None: diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index b43b1881f1e..a1637f46e7d 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -138,8 +138,7 @@ def _load_image(self, path: str) -> Tensor: """ with rasterio.open(path) as f: array = f.read() - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 3a43ddddcc5..38a0d3eee91 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -3,8 +3,6 @@ """Northeastern China Crop Map Dataset.""" -import glob -import os from collections.abc import Iterable from typing import Any, Callable, Optional, Union @@ -14,7 +12,7 @@ from rasterio.crs import CRS from .geo import RasterDataset -from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive +from .utils import BoundingBox, DatasetNotFoundError, download_url class NCCM(RasterDataset): @@ -55,12 +53,24 @@ class NCCM(RasterDataset): filename_regex = r"CDL(?P\d{4})_clip" filename_glob = "CDL*.*" - zipfile_glob = "13090442.zip" date_format = "%Y" is_image = False - url = "https://figshare.com/ndownloader/articles/13090442/versions/1" - md5 = "eae952f1b346d7e649d027e8139a76f5" + urls = { + 2019: "https://figshare.com/ndownloader/files/25070540", + 2018: "https://figshare.com/ndownloader/files/25070624", + 2017: "https://figshare.com/ndownloader/files/25070582", + } + md5s = { + 2019: "0d062bbd42e483fdc8239d22dba7020f", + 2018: "b3bb4894478d10786aa798fb11693ec1", + 2017: "d047fbe4a85341fa6248fd7e0badab6c", + } + fnames = { + 2019: "CDL2019_clip.tif", + 2018: "CDL2018_clip1.tif", + 2017: "CDL2017_clip.tif", + } cmap = { 0: (0, 255, 0, 255), @@ -75,6 +85,7 @@ def __init__( paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: Optional[float] = None, + years: list[int] = [2019], transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, cache: bool = True, download: bool = False, @@ -88,6 +99,7 @@ def __init__( (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS (defaults to the resolution of the first file found) + years: list of years for which to use nccm layers transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling @@ -97,7 +109,12 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ + assert set(years) <= self.md5s.keys(), ( + "NCCM data product only exists for the following years: " + f"{list(self.md5s.keys())}." + ) self.paths = paths + self.years = years self.download = download self.checksum = checksum self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype) @@ -128,37 +145,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - # Check if the extracted files already exist + # Check if the files already exist if self.files: return - # Check if the zip file has already been downloaded - assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "**", self.zipfile_glob) - if glob.glob(pathname, recursive=True): - self._extract() - return - # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) # Download the dataset self._download() - self._extract() def _download(self) -> None: """Download the dataset.""" - filename = "13090442.zip" - download_url( - self.url, self.paths, filename, md5=self.md5 if self.checksum else None - ) - - def _extract(self) -> None: - """Extract the dataset.""" - assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "**", self.zipfile_glob) - extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) + for year in self.years: + download_url( + self.urls[year], + self.paths, + filename=self.fnames[year], + md5=self.md5s[year] if self.checksum else None, + ) def plot( self, diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index db98e6ff2b4..eebe7348d0a 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -222,8 +222,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: with Image.open(path) as img: images.append(np.array(img)) array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_) - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 4e38eb4af30..84925a85022 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -235,8 +235,7 @@ def _load_semantic_targets(self, index: int) -> Tensor: # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501 # even though the mask file is 3 bands, we just select the first band array = np.load(self.files[index]["semantic"])[0].astype(np.uint8) - tensor = torch.from_numpy(array) - tensor = tensor.long() + tensor = torch.from_numpy(array).long() return tensor def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]: diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index a510a3f49fb..782fd7ce87d 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -192,8 +192,7 @@ def _load_image(self, index: int) -> Tensor: path = self.files[index]["image"] with rasterio.open(path) as f: array = f.read() - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, index: int) -> Tensor: diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 82e0438f2d0..b01ad3cae68 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -359,9 +359,7 @@ def _load_target(self, index: int) -> Tensor: path = self.files.iloc[index][0] with rasterio.open(f"{path}_labels.tif") as f: array = f.read() - 1 - tensor = torch.from_numpy(array) - tensor = tensor.squeeze() - tensor = tensor.long() + tensor = torch.from_numpy(array).squeeze().long() return tensor def _verify(self) -> None: diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index baa21c4495c..156b3f1568d 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -173,8 +173,7 @@ def _load_image(self, index: int) -> Tensor: else: arr = rearrange(arr, "h w c -> c h w") - tensor = torch.from_numpy(arr) - tensor = tensor.to(torch.float32) + tensor = torch.from_numpy(arr).to(torch.float32) return tensor def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]: diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 5cb5d1e2356..c6780e1971c 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -200,8 +200,7 @@ def _load_mask( dtype=np.uint8, ) - mask = torch.from_numpy(mask_data) - mask = mask.long() + mask = torch.from_numpy(mask_data).long() return mask @@ -727,8 +726,7 @@ def _load_mask( dtype=np.uint8, ) - mask = torch.from_numpy(mask_data) - mask = mask.long() + mask = torch.from_numpy(mask_data).long() return mask def plot( diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index a3f769aa789..003cccf2fcf 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -306,8 +306,7 @@ def _load_image(self, path: str) -> Tensor: image """ with rasterio.open(path) as src: - image = torch.from_numpy(src.read()) - image = image.float() + image = torch.from_numpy(src.read()).float() return image def _load_mask(self, path: str) -> Tensor: @@ -320,8 +319,7 @@ def _load_mask(self, path: str) -> Tensor: mask """ with rasterio.open(path) as src: - mask = torch.from_numpy(src.read()) - mask = mask.long() + mask = torch.from_numpy(src.read()).long() mask = self.ordinal_map[mask] return mask diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 81823cd848d..f33d268d4ce 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -181,8 +181,7 @@ def _load_image(self, path: str) -> Tensor: """ with rasterio.open(path) as f: array: "np.typing.NDArray[np.int_]" = f.read() - tensor = torch.from_numpy(array) - tensor = tensor.float() + tensor = torch.from_numpy(array).float() return tensor def _verify(self) -> None: diff --git a/torchgeo/losses/qr.py b/torchgeo/losses/qr.py index 9bffc4a915f..ecffa33e9fa 100644 --- a/torchgeo/losses/qr.py +++ b/torchgeo/losses/qr.py @@ -5,7 +5,6 @@ import torch import torch.nn.functional as F -from torch import Tensor from torch.nn.modules import Module @@ -29,16 +28,12 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: qr loss """ q = probs - # https://github.com/pytorch/pytorch/issues/116327 - q_bar: Tensor = q.mean(dim=(0, 2, 3)) - log_q_bar = torch.log(q_bar) - qbar_log_S: Tensor = q_bar * log_q_bar - qbar_log_S = qbar_log_S.sum() + q_bar = q.mean(dim=(0, 2, 3)) + qbar_log_S = (q_bar * torch.log(q_bar)).sum() - q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target)) - q_log_p = q_log_p.mean() + q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target)).mean() - loss: Tensor = qbar_log_S - q_log_p + loss = qbar_log_S - q_log_p return loss @@ -67,7 +62,6 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q) r = F.normalize(z * target, p=1, dim=1) - loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)) - loss = loss.mean() + loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)).mean() return loss diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index ebf46bcc610..59f42223cf1 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -94,7 +94,7 @@ def __init__( ), ) self.register_buffer( - "biases", torch.zeros(num_patches, requires_grad=False) + torch.tensor(bias) + "biases", torch.zeros(num_patches, requires_grad=False) + bias ) if mode == "empirical": diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index a60a787444a..ef9bb05f10f 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -128,7 +128,7 @@ def __init__( # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) if torch.sum(self.areas) == 0: - self.areas += torch.tensor(1) + self.areas += 1 def __iter__(self) -> Iterator[list[BoundingBox]]: """Return the indices of a dataset. diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1180044c6e4..7251f4b274f 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -128,7 +128,7 @@ def __init__( # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) if torch.sum(self.areas) == 0: - self.areas += torch.tensor(1) + self.areas += 1 def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 68bdb6c9c43..d6c0b62765e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -343,7 +343,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - backbone = utils.load_state_dict(backbone, state_dict) + utils.load_state_dict(backbone, state_dict) self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 9ac312051c7..5d8d10c9dce 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -137,7 +137,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index d2621a8da74..4dbc1e453c9 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -261,7 +261,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection (and prediction) head batch_norm = version == 3 diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index c58f033b5bc..9cc2ea56441 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -128,7 +128,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index a889be1c96f..27719eda224 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -172,7 +172,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection head input_dim = self.backbone.num_features diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index e1b6678ed73..b5cd8f1e923 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -71,7 +71,9 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: return key, module -def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module: +def load_state_dict( + model: Module, state_dict: "OrderedDict[str, Tensor]" +) -> tuple[list[str], list[str]]: """Load pretrained resnet weights to a model. Args: @@ -79,7 +81,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict: dict containing tensor parameters Returns: - the model with pretrained weights + The missing and unexpected keys Warns: If input channels in model != pretrained model input channels @@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict[output_module_key + ".bias"], ) - model.load_state_dict(state_dict, strict=False) - return model + missing_keys: list[str] + unexpected_keys: list[str] + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + return missing_keys, unexpected_keys def reinit_initial_conv_layer( diff --git a/torchgeo/transforms/color.py b/torchgeo/transforms/color.py index e89ed471a4a..5459fc2f854 100644 --- a/torchgeo/transforms/color.py +++ b/torchgeo/transforms/color.py @@ -70,11 +70,8 @@ def apply_transform( Returns: The augmented input. """ - weights = flags["weights"] - weights = weights[..., :, None, None] - weights = weights.to(input.device) - out: Tensor = input * weights + weights = flags["weights"][..., :, None, None].to(input.device) + out = input * weights out = out.sum(dim=-3) - out = out.unsqueeze(-3) - out = out.expand(input.shape) + out = out.unsqueeze(-3).expand(input.shape) return out