From f66621091d478d009368ef066066e4878873c5df Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 10 Oct 2024 09:01:06 +0300 Subject: [PATCH 1/4] Refactor generic image reader tests --- .../tests/reader_tests/test_generic_image.py | 554 ++++++++++-------- 1 file changed, 295 insertions(+), 259 deletions(-) diff --git a/satpy/tests/reader_tests/test_generic_image.py b/satpy/tests/reader_tests/test_generic_image.py index 40d7611eb4..21e2b5b09e 100644 --- a/satpy/tests/reader_tests/test_generic_image.py +++ b/satpy/tests/reader_tests/test_generic_image.py @@ -16,273 +16,309 @@ # satpy. If not, see . """Unittests for generic image reader.""" -import os -import unittest +import datetime as dt import dask.array as da import numpy as np import pytest import xarray as xr +from pyresample.geometry import AreaDefinition +from rasterio.errors import NotGeoreferencedWarning +from satpy import Scene +from satpy.readers.generic_image import GenericImageFileHandler from satpy.tests.utils import RANDOM_GEN, make_dataid +DATA_DATE = dt.datetime(2018, 1, 1) -class TestGenericImage(unittest.TestCase): - """Test generic image reader.""" - - def setUp(self): - """Create temporary images to test on.""" - import datetime as dt - import tempfile - - from pyresample.geometry import AreaDefinition - - from satpy.scene import Scene - - self.date = dt.datetime(2018, 1, 1) - - # Create area definition - pcs_id = "ETRS89 / LAEA Europe" - proj4_dict = "EPSG:3035" - self.x_size = 100 - self.y_size = 100 - area_extent = (2426378.0132, 1528101.2618, 6293974.6215, 5446513.5222) - self.area_def = AreaDefinition("geotiff_area", pcs_id, pcs_id, - proj4_dict, self.x_size, self.y_size, - area_extent) - - # Create datasets for L, LA, RGB and RGBA mode images - r__ = da.random.randint(0, 256, size=(self.y_size, self.x_size), - chunks=(50, 50)).astype(np.uint8) - g__ = da.random.randint(0, 256, size=(self.y_size, self.x_size), - chunks=(50, 50)).astype(np.uint8) - b__ = da.random.randint(0, 256, size=(self.y_size, self.x_size), - chunks=(50, 50)).astype(np.uint8) - a__ = 255 * np.ones((self.y_size, self.x_size), dtype=np.uint8) - a__[:10, :10] = 0 - a__ = da.from_array(a__, chunks=(50, 50)) - - r_nan__ = RANDOM_GEN.uniform(0., 1., size=(self.y_size, self.x_size)) - r_nan__[:10, :10] = np.nan - r_nan__ = da.from_array(r_nan__, chunks=(50, 50)) - - ds_l = xr.DataArray(da.stack([r__]), dims=("bands", "y", "x"), - attrs={"name": "test_l", - "start_time": self.date}) - ds_l["bands"] = ["L"] - ds_la = xr.DataArray(da.stack([r__, a__]), dims=("bands", "y", "x"), - attrs={"name": "test_la", - "start_time": self.date}) - ds_la["bands"] = ["L", "A"] - ds_rgb = xr.DataArray(da.stack([r__, g__, b__]), +X_SIZE = 100 +Y_SIZE = 100 +AREA_DEFINITION = AreaDefinition("geotiff_area", "ETRS89 / LAEA Europe", "ETRS89 / LAEA Europe", + "EPSG:3035", X_SIZE, Y_SIZE, + (2426378.0132, 1528101.2618, 6293974.6215, 5446513.5222)) + + +@pytest.fixture +def random_image_channel(): + """Create random data.""" + return da.random.randint(0, 256, size=(Y_SIZE, X_SIZE), chunks=(50, 50)).astype(np.uint8) + + +random_image_channel_l = random_image_channel +random_image_channel_r = random_image_channel +random_image_channel_g = random_image_channel +random_image_channel_b = random_image_channel + + +@pytest.fixture +def alpha_channel(): + """Create alpha channel with fully transparent and opaque areas.""" + a__ = 255 * np.ones((Y_SIZE, X_SIZE), dtype=np.uint8) + a__[:10, :10] = 0 + return da.from_array(a__, chunks=(50, 50)) + + +@pytest.fixture +def random_image_channel_with_nans(): + """Create random data and replace a portion of it with NaN values.""" + arr = RANDOM_GEN.uniform(0., 1., size=(Y_SIZE, X_SIZE)) + arr[:10, :10] = np.nan + return da.from_array(arr, chunks=(50, 50)) + + +@pytest.fixture +def test_image_l(tmp_path, random_image_channel_l): + """Create a test image with mode L.""" + dset = xr.DataArray(da.stack([random_image_channel_l]), dims=("bands", "y", "x"), + attrs={"name": "test_l", "start_time": DATA_DATE}) + dset["bands"] = ["L"] + fname = tmp_path / "test_l.png" + _save_image(dset, fname, "simple_image") + + return fname + + +@pytest.fixture +def test_image_l_nan(tmp_path, random_image_channel_with_nans): + """Create a test image with mode L where data has NaN values.""" + dset = xr.DataArray(da.stack([random_image_channel_with_nans]), dims=("bands", "y", "x"), + attrs={"name": "test_l_nan", "start_time": DATA_DATE}) + dset["bands"] = ["L"] + fname = tmp_path / "test_l_nan_nofillvalue.tif" + _save_image(dset, fname, "geotiff") + + return fname + + +@pytest.fixture +def test_image_l_nan_fill_value(tmp_path, random_image_channel_with_nans): + """Create a test image with mode L where data has NaN values and fill value is set.""" + dset = xr.DataArray(da.stack([random_image_channel_with_nans]), dims=("bands", "y", "x"), + attrs={"name": "test_l_nan", "start_time": DATA_DATE}) + dset["bands"] = ["L"] + fname = tmp_path / "test_l_nan_fillvalue.tif" + _save_image(dset, fname, "geotiff", fill_value=0) + + return fname + + +@pytest.fixture +def test_image_la(tmp_path, random_image_channel_l, alpha_channel): + """Create a test image with mode LA.""" + dset = xr.DataArray(da.stack([random_image_channel_l, alpha_channel]), + dims=("bands", "y", "x"), + attrs={"name": "test_la", "start_time": DATA_DATE}) + dset["bands"] = ["L", "A"] + fname = tmp_path / "20180101_0000_test_la.png" + _save_image(dset, fname, "simple_image") + + return fname + + +@pytest.fixture +def test_image_rgb(tmp_path, random_image_channel_r, random_image_channel_g, random_image_channel_b): + """Create a test image with mode RGB.""" + dset = xr.DataArray(da.stack([random_image_channel_r, random_image_channel_g, random_image_channel_b]), dims=("bands", "y", "x"), attrs={"name": "test_rgb", - "start_time": self.date}) - ds_rgb["bands"] = ["R", "G", "B"] - ds_rgba = xr.DataArray(da.stack([r__, g__, b__, a__]), - dims=("bands", "y", "x"), - attrs={"name": "test_rgba", - "start_time": self.date}) - ds_rgba["bands"] = ["R", "G", "B", "A"] - - ds_l_nan = xr.DataArray(da.stack([r_nan__]), - dims=("bands", "y", "x"), - attrs={"name": "test_l_nan", - "start_time": self.date}) - ds_l_nan["bands"] = ["L"] - - # Temp dir for the saved images - self.base_dir = tempfile.mkdtemp() - - # Put the datasets to Scene for easy saving - scn = Scene() - scn["l"] = ds_l - scn["l"].attrs["area"] = self.area_def - scn["la"] = ds_la - scn["la"].attrs["area"] = self.area_def - scn["rgb"] = ds_rgb - scn["rgb"].attrs["area"] = self.area_def - scn["rgba"] = ds_rgba - scn["rgba"].attrs["area"] = self.area_def - scn["l_nan"] = ds_l_nan - scn["l_nan"].attrs["area"] = self.area_def - - # Save the images. Two images in PNG and two in GeoTIFF - scn.save_dataset("l", os.path.join(self.base_dir, "test_l.png"), writer="simple_image") - scn.save_dataset("la", os.path.join(self.base_dir, "20180101_0000_test_la.png"), writer="simple_image") - scn.save_dataset("rgb", os.path.join(self.base_dir, "20180101_0000_test_rgb.tif"), writer="geotiff") - scn.save_dataset("rgba", os.path.join(self.base_dir, "test_rgba.tif"), writer="geotiff") - scn.save_dataset("l_nan", os.path.join(self.base_dir, "test_l_nan_fillvalue.tif"), - writer="geotiff", fill_value=0) - scn.save_dataset("l_nan", os.path.join(self.base_dir, "test_l_nan_nofillvalue.tif"), - writer="geotiff") - - self.scn = scn - - def tearDown(self): - """Remove the temporary directory created for a test.""" - try: - import shutil - shutil.rmtree(self.base_dir, ignore_errors=True) - except OSError: - pass - - def test_png_scene(self): - """Test reading PNG images via satpy.Scene().""" - from rasterio.errors import NotGeoreferencedWarning - - from satpy import Scene - - fname = os.path.join(self.base_dir, "test_l.png") - with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - assert scn["image"].shape == (1, self.y_size, self.x_size) - assert scn.sensor_names == {"images"} - assert scn.start_time is None - assert scn.end_time is None - assert "area" not in scn["image"].attrs - - fname = os.path.join(self.base_dir, "20180101_0000_test_la.png") - with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - data = da.compute(scn["image"].data) - assert scn["image"].shape == (1, self.y_size, self.x_size) - assert scn.sensor_names == {"images"} - assert scn.start_time == self.date - assert scn.end_time == self.date - assert "area" not in scn["image"].attrs - assert np.sum(np.isnan(data)) == 100 - - def test_geotiff_scene(self): - """Test reading TIFF images via satpy.Scene().""" - from satpy import Scene - - fname = os.path.join(self.base_dir, "20180101_0000_test_rgb.tif") - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - assert scn["image"].shape == (3, self.y_size, self.x_size) - assert scn.sensor_names == {"images"} - assert scn.start_time == self.date - assert scn.end_time == self.date - assert scn["image"].area == self.area_def - - fname = os.path.join(self.base_dir, "test_rgba.tif") - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - assert scn["image"].shape == (3, self.y_size, self.x_size) - assert scn.sensor_names == {"images"} - assert scn.start_time is None - assert scn.end_time is None - assert scn["image"].area == self.area_def - - def test_geotiff_scene_nan(self): - """Test reading TIFF images originally containing NaN values via satpy.Scene().""" - from satpy import Scene - - fname = os.path.join(self.base_dir, "test_l_nan_fillvalue.tif") - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - assert scn["image"].shape == (1, self.y_size, self.x_size) - assert np.sum(scn["image"].data[0][:10, :10].compute()) == 0 - - fname = os.path.join(self.base_dir, "test_l_nan_nofillvalue.tif") - scn = Scene(reader="generic_image", filenames=[fname]) - scn.load(["image"]) - assert scn["image"].shape == (1, self.y_size, self.x_size) - assert np.all(np.isnan(scn["image"].data[0][:10, :10].compute())) - - def test_GenericImageFileHandler(self): - """Test direct use of the reader.""" - from satpy.readers.generic_image import GenericImageFileHandler - - fname = os.path.join(self.base_dir, "test_rgba.tif") - fname_info = {"start_time": self.date} - ftype_info = {} - reader = GenericImageFileHandler(fname, fname_info, ftype_info) - - foo = make_dataid(name="image") - assert reader.file_content - assert reader.finfo["filename"] == fname - assert reader.finfo["start_time"] == self.date - assert reader.finfo["end_time"] == self.date - assert reader.area == self.area_def - assert reader.get_area_def(None) == self.area_def - assert reader.start_time == self.date - assert reader.end_time == self.date - - dataset = reader.get_dataset(foo, {}) - assert isinstance(dataset, xr.DataArray) - assert "spatial_ref" in dataset.coords - assert np.all(np.isnan(dataset.data[:, :10, :10].compute())) - - def test_GenericImageFileHandler_masking_only_integer(self): - """Test direct use of the reader.""" - from satpy.readers.generic_image import GenericImageFileHandler - - class FakeGenericImageFileHandler(GenericImageFileHandler): - - def __init__(self, filename, filename_info, filetype_info, file_content, **kwargs): - """Get fake file content from 'get_test_content'.""" - super(GenericImageFileHandler, self).__init__(filename, filename_info, filetype_info) - self.file_content = file_content - self.dataset_name = None - self.file_content.update(kwargs) - - data = self.scn["rgba"] - - # do nothing if not integer - float_data = data / 255. - reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": float_data}) - assert reader.get_dataset(make_dataid(name="image"), {}) is float_data - - # masking if integer - data = data.astype(np.uint32) - assert data.bands.size == 4 - reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": data}) - ret_data = reader.get_dataset(make_dataid(name="image"), {}) - assert ret_data.bands.size == 3 - - def test_GenericImageFileHandler_datasetid(self): - """Test direct use of the reader.""" - from satpy.readers.generic_image import GenericImageFileHandler - - fname = os.path.join(self.base_dir, "test_rgba.tif") - fname_info = {"start_time": self.date} - ftype_info = {} - reader = GenericImageFileHandler(fname, fname_info, ftype_info) - - foo = make_dataid(name="image-custom") - assert reader.file_content - dataset = reader.get_dataset(foo, {}) - assert isinstance(dataset, xr.DataArray) - - def test_GenericImageFileHandler_nodata(self): - """Test nodata handling with direct use of the reader.""" - from satpy.readers.generic_image import GenericImageFileHandler - - fname = os.path.join(self.base_dir, "test_l_nan_fillvalue.tif") - fname_info = {"start_time": self.date} - ftype_info = {} - reader = GenericImageFileHandler(fname, fname_info, ftype_info) - - foo = make_dataid(name="image-custom") - assert reader.file_content - info = {"nodata_handling": "nan_mask"} - dataset = reader.get_dataset(foo, info) - assert isinstance(dataset, xr.DataArray) - assert np.all(np.isnan(dataset.data[0][:10, :10].compute())) - assert np.isnan(dataset.attrs["_FillValue"]) - - info = {"nodata_handling": "fill_value"} - dataset = reader.get_dataset(foo, info) - assert isinstance(dataset, xr.DataArray) - assert np.sum(dataset.data[0][:10, :10].compute()) == 0 - assert dataset.attrs["_FillValue"] == 0 - - # default same as 'nodata_handling': 'fill_value' - dataset = reader.get_dataset(foo, {}) - assert isinstance(dataset, xr.DataArray) - assert np.sum(dataset.data[0][:10, :10].compute()) == 0 - assert dataset.attrs["_FillValue"] == 0 + "start_time": DATA_DATE}) + dset["bands"] = ["R", "G", "B"] + fname = tmp_path / "20180101_0000_test_rgb.tif" + _save_image(dset, fname, "geotiff") + + return fname + + +@pytest.fixture +def rgba_dset(random_image_channel_r, random_image_channel_g, random_image_channel_b, alpha_channel): + """Create an RGB dataset.""" + dset = xr.DataArray( + da.stack([random_image_channel_r, random_image_channel_g, random_image_channel_b, alpha_channel]), + dims=("bands", "y", "x"), + attrs={"name": "test_rgba", + "start_time": DATA_DATE}) + dset["bands"] = ["R", "G", "B", "A"] + return dset + + +@pytest.fixture +def test_image_rgba(tmp_path, rgba_dset): + """Create a test image with mode RGBA.""" + fname = tmp_path / "test_rgba.tif" + _save_image(rgba_dset, fname, "geotiff") + + return fname + + +def _save_image(dset, fname, writer, fill_value=None): + scn = Scene() + scn["data"] = dset + scn["data"].attrs["area"] = AREA_DEFINITION + scn.save_dataset("data", str(fname), writer=writer, fill_value=fill_value) + + +def test_png_scene_l_mode(test_image_l): + """Test reading a PNG image with L mode via satpy.Scene().""" + with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): + scn = Scene(reader="generic_image", filenames=[test_image_l]) + scn.load(["image"]) + assert scn["image"].shape == (1, Y_SIZE, X_SIZE) + assert scn.sensor_names == {"images"} + assert scn.start_time is None + assert scn.end_time is None + assert "area" not in scn["image"].attrs + + +def test_png_scene_la_mode(test_image_la): + """Test reading a PNG image with LA mode via satpy.Scene().""" + with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): + scn = Scene(reader="generic_image", filenames=[test_image_la]) + scn.load(["image"]) + data = da.compute(scn["image"].data) + assert scn["image"].shape == (1, Y_SIZE, X_SIZE) + assert scn.sensor_names == {"images"} + assert scn.start_time == DATA_DATE + assert scn.end_time == DATA_DATE + assert "area" not in scn["image"].attrs + assert np.sum(np.isnan(data)) == 100 + + +def test_geotiff_scene_rgb(test_image_rgb): + """Test reading geotiff image in RGB mode via satpy.Scene().""" + scn = Scene(reader="generic_image", filenames=[test_image_rgb]) + scn.load(["image"]) + assert scn["image"].shape == (3, Y_SIZE, X_SIZE) + assert scn.sensor_names == {"images"} + assert scn.start_time == DATA_DATE + assert scn.end_time == DATA_DATE + assert scn["image"].area == AREA_DEFINITION + + +def test_geotiff_scene_rgba(test_image_rgba): + """Test reading geotiff image in RGBA mode via satpy.Scene().""" + scn = Scene(reader="generic_image", filenames=[test_image_rgba]) + scn.load(["image"]) + assert scn["image"].shape == (3, Y_SIZE, X_SIZE) + assert scn.sensor_names == {"images"} + assert scn.start_time is None + assert scn.end_time is None + assert scn["image"].area == AREA_DEFINITION + + +def test_geotiff_scene_nan_fill_value(test_image_l_nan_fill_value): + """Test reading geotiff image with fill value set via satpy.Scene().""" + scn = Scene(reader="generic_image", filenames=[test_image_l_nan_fill_value]) + scn.load(["image"]) + assert scn["image"].shape == (1, Y_SIZE, X_SIZE) + assert np.sum(scn["image"].data[0][:10, :10].compute()) == 0 + + +def test_geotiff_scene_nan(test_image_l_nan): + """Test reading geotiff image with NaN values in it via satpy.Scene().""" + scn = Scene(reader="generic_image", filenames=[test_image_l_nan]) + scn.load(["image"]) + assert scn["image"].shape == (1, Y_SIZE, X_SIZE) + assert np.all(np.isnan(scn["image"].data[0][:10, :10].compute())) + + +def test_GenericImageFileHandler(test_image_rgba): + """Test direct use of the reader.""" + from satpy.readers.generic_image import GenericImageFileHandler + + fname_info = {"start_time": DATA_DATE} + ftype_info = {} + reader = GenericImageFileHandler(test_image_rgba, fname_info, ftype_info) + + data_id = make_dataid(name="image") + assert reader.file_content + assert reader.finfo["filename"] == test_image_rgba + assert reader.finfo["start_time"] == DATA_DATE + assert reader.finfo["end_time"] == DATA_DATE + assert reader.area == AREA_DEFINITION + assert reader.get_area_def(None) == AREA_DEFINITION + assert reader.start_time == DATA_DATE + assert reader.end_time == DATA_DATE + + dataset = reader.get_dataset(data_id, {}) + assert isinstance(dataset, xr.DataArray) + assert "spatial_ref" in dataset.coords + assert np.all(np.isnan(dataset.data[:, :10, :10].compute())) + + +class FakeGenericImageFileHandler(GenericImageFileHandler): + """Fake file handler.""" + + def __init__(self, filename, filename_info, filetype_info, file_content, **kwargs): + """Get fake file content from 'get_test_content'.""" + super(GenericImageFileHandler, self).__init__(filename, filename_info, filetype_info) + self.file_content = file_content + self.dataset_name = None + self.file_content.update(kwargs) + + +def test_GenericImageFileHandler_no_masking_for_float(rgba_dset): + """Test direct use of the reader for float_data.""" + # do nothing if not integer + float_data = rgba_dset / 255. + reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": float_data}) + assert reader.get_dataset(make_dataid(name="image"), {}) is float_data + + +def test_GenericImageFileHandler_masking_for_integer(rgba_dset): + """Test direct use of the reader for float_data.""" + # masking if integer + data = rgba_dset.astype(np.uint32) + assert data.bands.size == 4 + reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": data}) + ret_data = reader.get_dataset(make_dataid(name="image"), {}) + assert ret_data.bands.size == 3 + + +def test_GenericImageFileHandler_datasetid(test_image_rgba): + """Test direct use of the reader.""" + fname_info = {"start_time": DATA_DATE} + ftype_info = {} + reader = GenericImageFileHandler(test_image_rgba, fname_info, ftype_info) + + data_id = make_dataid(name="image-custom") + assert reader.file_content + dataset = reader.get_dataset(data_id, {}) + assert isinstance(dataset, xr.DataArray) + + +@pytest.fixture +def reader_l_nan_fill_value(test_image_l_nan_fill_value): + """Create GenericImageFileHandler.""" + fname_info = {"start_time": DATA_DATE} + ftype_info = {} + return GenericImageFileHandler(test_image_l_nan_fill_value, fname_info, ftype_info) + + +def test_GenericImageFileHandler_nodata_nan_mask(reader_l_nan_fill_value): + """Test nodata handling with direct use of the reader with nodata handling: nan_mask.""" + data_id = make_dataid(name="image-custom") + assert reader_l_nan_fill_value.file_content + info = {"nodata_handling": "nan_mask"} + dataset = reader_l_nan_fill_value.get_dataset(data_id, info) + assert isinstance(dataset, xr.DataArray) + assert np.all(np.isnan(dataset.data[0][:10, :10].compute())) + assert np.isnan(dataset.attrs["_FillValue"]) + + +def test_GenericImageFileHandler_nodata_fill_value(reader_l_nan_fill_value): + """Test nodata handling with direct use of the reader with nodata handling: fill_value.""" + info = {"nodata_handling": "fill_value"} + data_id = make_dataid(name="image-custom") + dataset = reader_l_nan_fill_value.get_dataset(data_id, info) + assert isinstance(dataset, xr.DataArray) + assert np.sum(dataset.data[0][:10, :10].compute()) == 0 + assert dataset.attrs["_FillValue"] == 0 + + +def test_GenericImageFileHandler_nodata_nan_mask_default(reader_l_nan_fill_value): + """Test nodata handling with direct use of the reader with default nodata handling.""" + data_id = make_dataid(name="image-custom") + dataset = reader_l_nan_fill_value.get_dataset(data_id, {}) + assert isinstance(dataset, xr.DataArray) + assert np.sum(dataset.data[0][:10, :10].compute()) == 0 + assert dataset.attrs["_FillValue"] == 0 From ac433d69d411c9030056d5bb3721458e7401ac4a Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 10 Oct 2024 11:05:27 +0300 Subject: [PATCH 2/4] Fix generic image reader to return float32 when float is needed --- satpy/readers/generic_image.py | 2 +- satpy/tests/reader_tests/test_generic_image.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/satpy/readers/generic_image.py b/satpy/readers/generic_image.py index 5032b7bbb8..c0e334302f 100644 --- a/satpy/readers/generic_image.py +++ b/satpy/readers/generic_image.py @@ -147,7 +147,7 @@ def _mask_image_data(data, info): if not np.issubdtype(data.dtype, np.integer): raise ValueError("Only integer datatypes can be used as a mask.") mask = data.data[-1, :, :] == np.iinfo(data.dtype).min - data = data.astype(np.float64) + data = data.astype(np.float32) masked_data = da.stack([da.where(mask, np.nan, data.data[i, :, :]) for i in range(data.shape[0])]) data.data = masked_data diff --git a/satpy/tests/reader_tests/test_generic_image.py b/satpy/tests/reader_tests/test_generic_image.py index 21e2b5b09e..1ab3e073ec 100644 --- a/satpy/tests/reader_tests/test_generic_image.py +++ b/satpy/tests/reader_tests/test_generic_image.py @@ -167,6 +167,7 @@ def test_png_scene_l_mode(test_image_l): assert scn.start_time is None assert scn.end_time is None assert "area" not in scn["image"].attrs + assert scn["image"].dtype == np.float32 def test_png_scene_la_mode(test_image_la): @@ -181,6 +182,7 @@ def test_png_scene_la_mode(test_image_la): assert scn.end_time == DATA_DATE assert "area" not in scn["image"].attrs assert np.sum(np.isnan(data)) == 100 + assert scn["image"].dtype == np.float32 def test_geotiff_scene_rgb(test_image_rgb): @@ -192,6 +194,7 @@ def test_geotiff_scene_rgb(test_image_rgb): assert scn.start_time == DATA_DATE assert scn.end_time == DATA_DATE assert scn["image"].area == AREA_DEFINITION + assert scn["image"].dtype == np.float32 def test_geotiff_scene_rgba(test_image_rgba): @@ -203,6 +206,7 @@ def test_geotiff_scene_rgba(test_image_rgba): assert scn.start_time is None assert scn.end_time is None assert scn["image"].area == AREA_DEFINITION + assert scn["image"].dtype == np.float32 def test_geotiff_scene_nan_fill_value(test_image_l_nan_fill_value): @@ -211,6 +215,7 @@ def test_geotiff_scene_nan_fill_value(test_image_l_nan_fill_value): scn.load(["image"]) assert scn["image"].shape == (1, Y_SIZE, X_SIZE) assert np.sum(scn["image"].data[0][:10, :10].compute()) == 0 + assert scn["image"].dtype == np.uint8 def test_geotiff_scene_nan(test_image_l_nan): @@ -219,6 +224,7 @@ def test_geotiff_scene_nan(test_image_l_nan): scn.load(["image"]) assert scn["image"].shape == (1, Y_SIZE, X_SIZE) assert np.all(np.isnan(scn["image"].data[0][:10, :10].compute())) + assert scn["image"].dtype == np.float32 def test_GenericImageFileHandler(test_image_rgba): From d129942fc97918a79c746a235f9d374e4fe80aae Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 10 Oct 2024 11:11:21 +0300 Subject: [PATCH 3/4] Fix add_bands() to not promote the data type when adding alpha band --- satpy/composites/__init__.py | 1 + satpy/tests/test_composites.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/satpy/composites/__init__.py b/satpy/composites/__init__.py index 0ac99c98f7..b032f23a32 100644 --- a/satpy/composites/__init__.py +++ b/satpy/composites/__init__.py @@ -985,6 +985,7 @@ def add_bands(data, bands): alpha = new_data[0].copy() alpha.data = da.ones((data.sizes["y"], data.sizes["x"]), + dtype=new_data[0].dtype, chunks=new_data[0].chunks) # Rename band to indicate it's alpha alpha["bands"] = "A" diff --git a/satpy/tests/test_composites.py b/satpy/tests/test_composites.py index 2af010e9ac..1b60161a52 100644 --- a/satpy/tests/test_composites.py +++ b/satpy/tests/test_composites.py @@ -1302,7 +1302,7 @@ def test_add_bands_l_rgb(self): from satpy.composites import add_bands # L + RGB -> RGB - data = xr.DataArray(da.ones((1, 3, 3)), dims=("bands", "y", "x"), + data = xr.DataArray(da.ones((1, 3, 3), dtype="float32"), dims=("bands", "y", "x"), coords={"bands": ["L"]}) new_bands = xr.DataArray(da.array(["R", "G", "B"]), dims=("bands"), coords={"bands": ["R", "G", "B"]}) @@ -1311,13 +1311,14 @@ def test_add_bands_l_rgb(self): assert res.attrs["mode"] == "".join(res_bands) np.testing.assert_array_equal(res.bands, res_bands) np.testing.assert_array_equal(res.coords["bands"], res_bands) + assert res.dtype == np.float32 def test_add_bands_l_rgba(self): """Test adding bands.""" from satpy.composites import add_bands # L + RGBA -> RGBA - data = xr.DataArray(da.ones((1, 3, 3)), dims=("bands", "y", "x"), + data = xr.DataArray(da.ones((1, 3, 3), dtype="float32"), dims=("bands", "y", "x"), coords={"bands": ["L"]}, attrs={"mode": "L"}) new_bands = xr.DataArray(da.array(["R", "G", "B", "A"]), dims=("bands"), coords={"bands": ["R", "G", "B", "A"]}) @@ -1326,13 +1327,14 @@ def test_add_bands_l_rgba(self): assert res.attrs["mode"] == "".join(res_bands) np.testing.assert_array_equal(res.bands, res_bands) np.testing.assert_array_equal(res.coords["bands"], res_bands) + assert res.dtype == np.float32 def test_add_bands_la_rgb(self): """Test adding bands.""" from satpy.composites import add_bands # LA + RGB -> RGBA - data = xr.DataArray(da.ones((2, 3, 3)), dims=("bands", "y", "x"), + data = xr.DataArray(da.ones((2, 3, 3), dtype="float32"), dims=("bands", "y", "x"), coords={"bands": ["L", "A"]}, attrs={"mode": "LA"}) new_bands = xr.DataArray(da.array(["R", "G", "B"]), dims=("bands"), coords={"bands": ["R", "G", "B"]}) @@ -1341,13 +1343,14 @@ def test_add_bands_la_rgb(self): assert res.attrs["mode"] == "".join(res_bands) np.testing.assert_array_equal(res.bands, res_bands) np.testing.assert_array_equal(res.coords["bands"], res_bands) + assert res.dtype == np.float32 def test_add_bands_rgb_rbga(self): """Test adding bands.""" from satpy.composites import add_bands # RGB + RGBA -> RGBA - data = xr.DataArray(da.ones((3, 3, 3)), dims=("bands", "y", "x"), + data = xr.DataArray(da.ones((3, 3, 3), dtype="float32"), dims=("bands", "y", "x"), coords={"bands": ["R", "G", "B"]}, attrs={"mode": "RGB"}) new_bands = xr.DataArray(da.array(["R", "G", "B", "A"]), dims=("bands"), @@ -1357,6 +1360,7 @@ def test_add_bands_rgb_rbga(self): assert res.attrs["mode"] == "".join(res_bands) np.testing.assert_array_equal(res.bands, res_bands) np.testing.assert_array_equal(res.coords["bands"], res_bands) + assert res.dtype == np.float32 def test_add_bands_p_l(self): """Test adding bands.""" From c6db95d92a61d8a2b2f09d378989cf626e21fc8e Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 10 Oct 2024 15:27:45 +0300 Subject: [PATCH 4/4] Put common asserts to a helper function --- .../tests/reader_tests/test_generic_image.py | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/satpy/tests/reader_tests/test_generic_image.py b/satpy/tests/reader_tests/test_generic_image.py index 1ab3e073ec..0d5d647420 100644 --- a/satpy/tests/reader_tests/test_generic_image.py +++ b/satpy/tests/reader_tests/test_generic_image.py @@ -162,12 +162,20 @@ def test_png_scene_l_mode(test_image_l): with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): scn = Scene(reader="generic_image", filenames=[test_image_l]) scn.load(["image"]) - assert scn["image"].shape == (1, Y_SIZE, X_SIZE) - assert scn.sensor_names == {"images"} - assert scn.start_time is None - assert scn.end_time is None + _assert_image_common(scn, 1, None, None, np.float32) assert "area" not in scn["image"].attrs - assert scn["image"].dtype == np.float32 + + +def _assert_image_common(scn, channels, start_time, end_time, dtype): + assert scn["image"].shape == (channels, Y_SIZE, X_SIZE) + assert scn.sensor_names == {"images"} + try: + assert scn.start_time is start_time + assert scn.end_time is end_time + except AssertionError: + assert scn.start_time == start_time + assert scn.end_time == end_time + assert scn["image"].dtype == dtype def test_png_scene_la_mode(test_image_la): @@ -176,55 +184,40 @@ def test_png_scene_la_mode(test_image_la): scn = Scene(reader="generic_image", filenames=[test_image_la]) scn.load(["image"]) data = da.compute(scn["image"].data) - assert scn["image"].shape == (1, Y_SIZE, X_SIZE) - assert scn.sensor_names == {"images"} - assert scn.start_time == DATA_DATE - assert scn.end_time == DATA_DATE - assert "area" not in scn["image"].attrs assert np.sum(np.isnan(data)) == 100 - assert scn["image"].dtype == np.float32 + assert "area" not in scn["image"].attrs + _assert_image_common(scn, 1, DATA_DATE, DATA_DATE, np.float32) def test_geotiff_scene_rgb(test_image_rgb): """Test reading geotiff image in RGB mode via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_rgb]) scn.load(["image"]) - assert scn["image"].shape == (3, Y_SIZE, X_SIZE) - assert scn.sensor_names == {"images"} - assert scn.start_time == DATA_DATE - assert scn.end_time == DATA_DATE assert scn["image"].area == AREA_DEFINITION - assert scn["image"].dtype == np.float32 + _assert_image_common(scn, 3, DATA_DATE, DATA_DATE, np.float32) def test_geotiff_scene_rgba(test_image_rgba): """Test reading geotiff image in RGBA mode via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_rgba]) scn.load(["image"]) - assert scn["image"].shape == (3, Y_SIZE, X_SIZE) - assert scn.sensor_names == {"images"} - assert scn.start_time is None - assert scn.end_time is None + _assert_image_common(scn, 3, None, None, np.float32) assert scn["image"].area == AREA_DEFINITION - assert scn["image"].dtype == np.float32 def test_geotiff_scene_nan_fill_value(test_image_l_nan_fill_value): """Test reading geotiff image with fill value set via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_l_nan_fill_value]) scn.load(["image"]) - assert scn["image"].shape == (1, Y_SIZE, X_SIZE) assert np.sum(scn["image"].data[0][:10, :10].compute()) == 0 - assert scn["image"].dtype == np.uint8 - + _assert_image_common(scn, 1, None, None, np.uint8) def test_geotiff_scene_nan(test_image_l_nan): """Test reading geotiff image with NaN values in it via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_l_nan]) scn.load(["image"]) - assert scn["image"].shape == (1, Y_SIZE, X_SIZE) assert np.all(np.isnan(scn["image"].data[0][:10, :10].compute())) - assert scn["image"].dtype == np.float32 + _assert_image_common(scn, 1, None, None, np.float32) def test_GenericImageFileHandler(test_image_rgba):