From 264313f2ec3788c749363f4c77bd1d47301808c4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 22 Apr 2024 10:39:33 +0200 Subject: [PATCH] cubic -> bilinear --- docs/tutorials/custom_raster_dataset.ipynb | 2 +- tests/datasets/test_geo.py | 2 +- torchgeo/datasets/geo.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 039db58b62e..e1091c7be4d 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -338,7 +338,7 @@ "\n", "### `resampling`\n", "\n", - "Defaults to cubic for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n", + "Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n", "\n", "### `separate_files`\n", "\n", diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index c42f460a289..e870dd66af4 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -292,7 +292,7 @@ def test_resampling_float_dtype(self, dtype: torch.dtype) -> None: ds = CustomRasterDataset(dtype, paths) x = ds[ds.bounds] assert x["image"].dtype == dtype - assert ds.resampling == Resampling.cubic + assert ds.resampling == Resampling.bilinear @pytest.mark.parametrize("dtype", [torch.long, torch.bool]) def test_resampling_int_dtype(self, dtype: torch.dtype) -> None: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 98b781058e8..72bde924cd9 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -389,7 +389,7 @@ def dtype(self) -> torch.dtype: def resampling(self) -> Resampling: """Resampling algorithm used when reading input files. - Defaults to cubic for float dtypes and nearest for int dtypes. + Defaults to bilinear for float dtypes and nearest for int dtypes. Returns: The resampling method to use. @@ -398,7 +398,7 @@ def resampling(self) -> Resampling: """ # Based on torch.is_floating_point if self.dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]: - return Resampling.cubic + return Resampling.bilinear else: return Resampling.nearest