From e3a93d8122c05ffc59f31d38f284a10bcb722c9f Mon Sep 17 00:00:00 2001 From: oczoske Date: Sun, 17 Nov 2024 11:04:49 +0000 Subject: [PATCH] make function more robust against dimension mismatches --- scopesim/optics/image_plane_utils.py | 40 ++++------- .../mocks/py_objects/imagehdu_objects.py | 27 +++++++ .../tests/tests_optics/test_ImagePlane.py | 72 +++++++++++++------ 3 files changed, 91 insertions(+), 48 deletions(-) diff --git a/scopesim/optics/image_plane_utils.py b/scopesim/optics/image_plane_utils.py index 4e76e76e..1f06bca8 100644 --- a/scopesim/optics/image_plane_utils.py +++ b/scopesim/optics/image_plane_utils.py @@ -487,19 +487,22 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity, primary_wcs = WCS(imagehdu.header, key=wcs_suffix[0]) # make sure that units are correct and zoom factor is positive + # The length of the zoom factor will be determined by imagehdu.data, + # which might differ from the dimension of primary_wcs. Here, pick + # the spatial dimensions only. pixel_scale = pixel_scale << u.Unit(primary_wcs.wcs.cunit[0]) - zoom = np.abs(primary_wcs.wcs.cdelt / pixel_scale.value) + zoom = np.abs(primary_wcs.wcs.cdelt[:2] / pixel_scale.value) + + if len(imagehdu.data.shape) == 3: + zoom = np.append(zoom, [1.]) # wavelength dimension unscaled if present + + logger.debug("zoom factor: %s", zoom) - if primary_wcs.naxis == 3: - # zoom = np.append(zoom, [1]) - zoom[2] = 1. if primary_wcs.naxis != imagehdu.data.ndim: + # FIXME: this happens often - shouldn't WCSs be trimmed down before? (OC) logger.warning("imagehdu.data.ndim is %d, but primary_wcs.naxis with " - "key %s is %d, both should be equal.", - imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis) - zoom = zoom[:2] - - logger.debug("zoom %s", zoom) + "key %s is %d, both should be equal.", + imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis) if all(zoom == 1.): # Nothing to do @@ -525,28 +528,15 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity, logger.warning("imagehdu.data.ndim is %d, but wcs.naxis with key " "%s is %d, both should be equal.", imagehdu.data.ndim, ww.wcs.alt, ww.naxis) - # TODO: could this be ww = ww.sub(2) instead? or .celestial? - # ww = WCS(imagehdu.header, key=key, naxis=imagehdu.data.ndim) if any(ctype != "LINEAR" for ctype in ww.wcs.ctype): logger.warning("Non-linear WCS rescaled using linear procedure.") - new_crpix = (zoom + 1) / 2 + (ww.wcs.crpix - 1) * zoom - #ew_crpix = np.round(new_crpix * 2) / 2 # round to nearest half-pixel - logger.debug("new crpix %s", new_crpix) - ww.wcs.crpix = new_crpix + ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2] + logger.debug("new crpix %s", ww.wcs.crpix) # Keep CDELT3 if cube... - new_cdelt = ww.wcs.cdelt[:] - new_cdelt /= zoom - ww.wcs.cdelt = new_cdelt - - # TODO: is forcing deg here really the best way? - # FIXME: NO THIS WILL MESS UP IF new_cdelt IS IN ARCSEC!!!!! - # new_cunit = [str(cunit) for cunit in ww.wcs.cunit] - # new_cunit[0] = "mm" if key == "D" else "deg" - # new_cunit[1] = "mm" if key == "D" else "deg" - # ww.wcs.cunit = new_cunit + ww.wcs.cdelt[:2] /= zoom[:2] imagehdu.header.update(ww.to_header()) diff --git a/scopesim/tests/mocks/py_objects/imagehdu_objects.py b/scopesim/tests/mocks/py_objects/imagehdu_objects.py index 18f0dd7d..a5592fae 100644 --- a/scopesim/tests/mocks/py_objects/imagehdu_objects.py +++ b/scopesim/tests/mocks/py_objects/imagehdu_objects.py @@ -73,3 +73,30 @@ def _image_hdu_three_wcs(): hdu.header.update(wcs_g.to_header()) return hdu + +def _image_hdu_3d_data(): + nx, ny = 100, 100 + nz = 3 + + # a 3D WCS + the_wcs0 = wcs.WCS(naxis=3, key="") + the_wcs0.wcs.ctype = ["LINEAR", "LINEAR", "WAVE"] + the_wcs0.wcs.cunit = ["arcsec", "arcsec", "um"] + the_wcs0.wcs.cdelt = [1, 1, 0.1] + the_wcs0.wcs.crval = [0, 0, 2.2] + the_wcs0.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2, 1] + + # a 2D WCS for spatial dimensions + the_wcsd = wcs.WCS(naxis=2, key="D") + the_wcsd.wcs.ctype = ["LINEAR", "LINEAR"] + the_wcsd.wcs.cunit = ["mm", "mm"] + the_wcsd.wcs.cdelt = [1, 1] + the_wcsd.wcs.crval = [0, 0] + the_wcsd.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2] + + image = np.ones((nz, ny, nx)) + hdr = the_wcs0.to_header() + hdr.extend(the_wcsd.to_header()) + hdu = fits.ImageHDU(data=image, header=hdr) + + return hdu diff --git a/scopesim/tests/tests_optics/test_ImagePlane.py b/scopesim/tests/tests_optics/test_ImagePlane.py index 97160d03..2a414b1c 100644 --- a/scopesim/tests/tests_optics/test_ImagePlane.py +++ b/scopesim/tests/tests_optics/test_ImagePlane.py @@ -1,6 +1,12 @@ +"""Tests for ImagePlane and some ImagePlaneUtils""" + +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +from copy import deepcopy + import pytest from pytest import approx -from copy import deepcopy import numpy as np from astropy.io import fits @@ -8,44 +14,51 @@ from astropy.table import Table from astropy import wcs +import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm + import scopesim.optics.image_plane as opt_imp import scopesim.optics.image_plane_utils as imp_utils from scopesim.tests.mocks.py_objects.imagehdu_objects import \ - _image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs - -import matplotlib.pyplot as plt -from matplotlib.colors import LogNorm - + _image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs,\ + _image_hdu_3d_data PLOTS = False -@pytest.fixture(scope="function") -def image_hdu_rect(): +@pytest.fixture(scope="function", name="image_hdu_rect") +def fixture_image_hdu_rect(): return _image_hdu_rect() -@pytest.fixture(scope="function") -def image_hdu_rect_mm(): +@pytest.fixture(scope="function", name="image_hdu_rect_mm") +def fixture_image_hdu_rect_mm(): return _image_hdu_rect("D") -@pytest.fixture(scope="function") -def image_hdu_square(): +@pytest.fixture(scope="function", name="image_hdu_square") +def fixture_image_hdu_square(): return _image_hdu_square() -@pytest.fixture(scope="function") -def image_hdu_square_mm(): +@pytest.fixture(scope="function", name="image_hdu_square_mm") +def fixture_image_hdu_square_mm(): return _image_hdu_square("D") -@pytest.fixture(scope="function") -def image_hdu_three_wcs(): + +@pytest.fixture(scope="function", name="image_hdu_three_wcs") +def fixture_image_hdu_three_wcs(): return _image_hdu_three_wcs() -@pytest.fixture(scope="function") -def input_table(): + +@pytest.fixture(scope="function", name="image_hdu_3d_data") +def fixture_image_hdu_3d_data(): + return _image_hdu_3d_data() + + +@pytest.fixture(scope="function", name="input_table") +def fixture_input_table(): x = [-10, -10, 0, 10, 10] * u.arcsec y = [-10, 10, 0, -10, 10] * u.arcsec f = [1, 3, 1, 1, 5] @@ -54,8 +67,8 @@ def input_table(): return tbl -@pytest.fixture(scope="function") -def input_table_mm(): +@pytest.fixture(scope="function", name="input_table_mm") +def fixture_input_table_mm(): x = [-10, -10, 0, 10, 10] * u.mm y = [-10, 10, 0, -10, 10] * u.mm f = [1, 3, 1, 1, 5] @@ -312,7 +325,7 @@ def test_points_are_added_to_small_canvas(self, input_table): assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"]) if PLOTS: - "top left is green, top right is yellow" + # "top left is green, top right is yellow" plt.imshow(canvas_hdu.data, origin="lower") plt.show() @@ -328,7 +341,7 @@ def test_mm_points_are_added_to_small_canvas(self, input_table_mm): assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"]) if PLOTS: - "top left is green, top right is yellow" + # "top left is green, top right is yellow" plt.imshow(canvas_hdu.data, origin="lower") plt.show() @@ -387,7 +400,7 @@ def test_mm_points_are_added_to_massive_canvas(self, input_table_mm): if PLOTS: x, y = imp_utils.val2pix(hdr, 0, 0, "D") plt.plot(x, y, "ro") - "top left is green, top right is yellow" + # "top left is green, top right is yellow" plt.imshow(canvas_hdu.data, origin="lower") plt.show() @@ -701,6 +714,19 @@ def test_rescale_works_on_nondefault_wcs(self, image_hdu_three_wcs): assert new_hdu.header['CDELT1D'] == 20 + def test_rescale_works_on_3d_imageplane(self, image_hdu_3d_data): + pixel_scale = 0.274 + wcses = wcs.find_all_wcs(image_hdu_3d_data.header) + fact = pixel_scale / wcses[0].wcs.cdelt[0] + + new_hdu = imp_utils.rescale_imagehdu(image_hdu_3d_data, pixel_scale) + new_wcses = wcs.find_all_wcs(new_hdu.header) + + assert new_wcses[0].wcs.cdelt[0] == pixel_scale + assert new_wcses[0].wcs.cdelt[2] == wcses[0].wcs.cdelt[2] + assert new_wcses[1].wcs.cdelt[1] / fact == approx(wcses[1].wcs.cdelt[1]) + + ############################################################################### # ..todo: When you have time, reintegrate these tests, There are some good ones