Skip to content

Commit

Permalink
make function more robust against dimension mismatches
Browse files Browse the repository at this point in the history
  • Loading branch information
oczoske committed Nov 17, 2024
1 parent f1a140f commit e3a93d8
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 48 deletions.
40 changes: 15 additions & 25 deletions scopesim/optics/image_plane_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,19 +487,22 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
primary_wcs = WCS(imagehdu.header, key=wcs_suffix[0])

# make sure that units are correct and zoom factor is positive
# The length of the zoom factor will be determined by imagehdu.data,
# which might differ from the dimension of primary_wcs. Here, pick
# the spatial dimensions only.
pixel_scale = pixel_scale << u.Unit(primary_wcs.wcs.cunit[0])
zoom = np.abs(primary_wcs.wcs.cdelt / pixel_scale.value)
zoom = np.abs(primary_wcs.wcs.cdelt[:2] / pixel_scale.value)

if len(imagehdu.data.shape) == 3:
zoom = np.append(zoom, [1.]) # wavelength dimension unscaled if present

logger.debug("zoom factor: %s", zoom)

if primary_wcs.naxis == 3:
# zoom = np.append(zoom, [1])
zoom[2] = 1.
if primary_wcs.naxis != imagehdu.data.ndim:
# FIXME: this happens often - shouldn't WCSs be trimmed down before? (OC)
logger.warning("imagehdu.data.ndim is %d, but primary_wcs.naxis with "
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)
zoom = zoom[:2]

logger.debug("zoom %s", zoom)
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)

if all(zoom == 1.):
# Nothing to do
Expand All @@ -525,28 +528,15 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
logger.warning("imagehdu.data.ndim is %d, but wcs.naxis with key "
"%s is %d, both should be equal.",
imagehdu.data.ndim, ww.wcs.alt, ww.naxis)
# TODO: could this be ww = ww.sub(2) instead? or .celestial?
# ww = WCS(imagehdu.header, key=key, naxis=imagehdu.data.ndim)

if any(ctype != "LINEAR" for ctype in ww.wcs.ctype):
logger.warning("Non-linear WCS rescaled using linear procedure.")

new_crpix = (zoom + 1) / 2 + (ww.wcs.crpix - 1) * zoom
#ew_crpix = np.round(new_crpix * 2) / 2 # round to nearest half-pixel
logger.debug("new crpix %s", new_crpix)
ww.wcs.crpix = new_crpix
ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2]
logger.debug("new crpix %s", ww.wcs.crpix)

# Keep CDELT3 if cube...
new_cdelt = ww.wcs.cdelt[:]
new_cdelt /= zoom
ww.wcs.cdelt = new_cdelt

# TODO: is forcing deg here really the best way?
# FIXME: NO THIS WILL MESS UP IF new_cdelt IS IN ARCSEC!!!!!
# new_cunit = [str(cunit) for cunit in ww.wcs.cunit]
# new_cunit[0] = "mm" if key == "D" else "deg"
# new_cunit[1] = "mm" if key == "D" else "deg"
# ww.wcs.cunit = new_cunit
ww.wcs.cdelt[:2] /= zoom[:2]

imagehdu.header.update(ww.to_header())

Expand Down
27 changes: 27 additions & 0 deletions scopesim/tests/mocks/py_objects/imagehdu_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,30 @@ def _image_hdu_three_wcs():
hdu.header.update(wcs_g.to_header())

return hdu

def _image_hdu_3d_data():
nx, ny = 100, 100
nz = 3

# a 3D WCS
the_wcs0 = wcs.WCS(naxis=3, key="")
the_wcs0.wcs.ctype = ["LINEAR", "LINEAR", "WAVE"]
the_wcs0.wcs.cunit = ["arcsec", "arcsec", "um"]
the_wcs0.wcs.cdelt = [1, 1, 0.1]
the_wcs0.wcs.crval = [0, 0, 2.2]
the_wcs0.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2, 1]

# a 2D WCS for spatial dimensions
the_wcsd = wcs.WCS(naxis=2, key="D")
the_wcsd.wcs.ctype = ["LINEAR", "LINEAR"]
the_wcsd.wcs.cunit = ["mm", "mm"]
the_wcsd.wcs.cdelt = [1, 1]
the_wcsd.wcs.crval = [0, 0]
the_wcsd.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2]

image = np.ones((nz, ny, nx))
hdr = the_wcs0.to_header()
hdr.extend(the_wcsd.to_header())
hdu = fits.ImageHDU(data=image, header=hdr)

return hdu
72 changes: 49 additions & 23 deletions scopesim/tests/tests_optics/test_ImagePlane.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,64 @@
"""Tests for ImagePlane and some ImagePlaneUtils"""

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

from copy import deepcopy

import pytest
from pytest import approx
from copy import deepcopy

import numpy as np
from astropy.io import fits
from astropy import units as u
from astropy.table import Table
from astropy import wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import scopesim.optics.image_plane as opt_imp
import scopesim.optics.image_plane_utils as imp_utils

from scopesim.tests.mocks.py_objects.imagehdu_objects import \
_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs,\
_image_hdu_3d_data

PLOTS = False


@pytest.fixture(scope="function")
def image_hdu_rect():
@pytest.fixture(scope="function", name="image_hdu_rect")
def fixture_image_hdu_rect():
return _image_hdu_rect()


@pytest.fixture(scope="function")
def image_hdu_rect_mm():
@pytest.fixture(scope="function", name="image_hdu_rect_mm")
def fixture_image_hdu_rect_mm():
return _image_hdu_rect("D")


@pytest.fixture(scope="function")
def image_hdu_square():
@pytest.fixture(scope="function", name="image_hdu_square")
def fixture_image_hdu_square():
return _image_hdu_square()


@pytest.fixture(scope="function")
def image_hdu_square_mm():
@pytest.fixture(scope="function", name="image_hdu_square_mm")
def fixture_image_hdu_square_mm():
return _image_hdu_square("D")

@pytest.fixture(scope="function")
def image_hdu_three_wcs():

@pytest.fixture(scope="function", name="image_hdu_three_wcs")
def fixture_image_hdu_three_wcs():
return _image_hdu_three_wcs()

@pytest.fixture(scope="function")
def input_table():

@pytest.fixture(scope="function", name="image_hdu_3d_data")
def fixture_image_hdu_3d_data():
return _image_hdu_3d_data()


@pytest.fixture(scope="function", name="input_table")
def fixture_input_table():
x = [-10, -10, 0, 10, 10] * u.arcsec
y = [-10, 10, 0, -10, 10] * u.arcsec
f = [1, 3, 1, 1, 5]
Expand All @@ -54,8 +67,8 @@ def input_table():
return tbl


@pytest.fixture(scope="function")
def input_table_mm():
@pytest.fixture(scope="function", name="input_table_mm")
def fixture_input_table_mm():
x = [-10, -10, 0, 10, 10] * u.mm
y = [-10, 10, 0, -10, 10] * u.mm
f = [1, 3, 1, 1, 5]
Expand Down Expand Up @@ -312,7 +325,7 @@ def test_points_are_added_to_small_canvas(self, input_table):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand All @@ -328,7 +341,7 @@ def test_mm_points_are_added_to_small_canvas(self, input_table_mm):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -387,7 +400,7 @@ def test_mm_points_are_added_to_massive_canvas(self, input_table_mm):
if PLOTS:
x, y = imp_utils.val2pix(hdr, 0, 0, "D")
plt.plot(x, y, "ro")
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -701,6 +714,19 @@ def test_rescale_works_on_nondefault_wcs(self, image_hdu_three_wcs):
assert new_hdu.header['CDELT1D'] == 20


def test_rescale_works_on_3d_imageplane(self, image_hdu_3d_data):
pixel_scale = 0.274
wcses = wcs.find_all_wcs(image_hdu_3d_data.header)
fact = pixel_scale / wcses[0].wcs.cdelt[0]

new_hdu = imp_utils.rescale_imagehdu(image_hdu_3d_data, pixel_scale)
new_wcses = wcs.find_all_wcs(new_hdu.header)

assert new_wcses[0].wcs.cdelt[0] == pixel_scale
assert new_wcses[0].wcs.cdelt[2] == wcses[0].wcs.cdelt[2]
assert new_wcses[1].wcs.cdelt[1] / fact == approx(wcses[1].wcs.cdelt[1])


###############################################################################
# ..todo: When you have time, reintegrate these tests, There are some good ones

Expand Down

0 comments on commit e3a93d8

Please sign in to comment.