Skip to content

Commit

Permalink
Merge pull request #649 from DanRyanIrish/unwrap_wcs
Browse files Browse the repository at this point in the history
Convert WCS wrappers to FITS WCS.
  • Loading branch information
nabobalis authored Nov 15, 2023
2 parents 673e7ee + b4edd34 commit 3639e5c
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 0 deletions.
2 changes: 2 additions & 0 deletions changelog/649.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Provides `~ndcube.wcs.tools.unwrap_wcs_to_fitswcs`, a function to create a `astropy.wcs.WCS` instance equivalent to a sliced and/or resampled WCS instance.
Only valid if the underlying implementation of the wrapped WCS instance is also an `astropy.wcs.WCS` instance.
2 changes: 2 additions & 0 deletions docs/reference/wcs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ wcs (`ndcube.wcs`)
.. automodapi:: ndcube.wcs

.. automodapi:: ndcube.wcs.wrappers

.. automodapi:: ndcube.wcs.tools
41 changes: 41 additions & 0 deletions ndcube/wcs/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from astropy.time import Time
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS
from numpy.testing import assert_array_almost_equal, assert_array_equal

from ndcube.wcs.tools import unwrap_wcs_to_fitswcs
from ndcube.wcs.wrappers import ResampledLowLevelWCS


def test_unwrap_wcs_to_fitswcs():
# Build FITS-WCS and wrap it in different operations.
time_ref = Time("2000-01-01T00:00:00", scale="utc", format="fits")
header = {
"CTYPE1": "TIME", "CTYPE2": "WAVE", "CTYPE3": "HPLT-TAN", "CTYPE4": "HPLN-TAN",
"CUNIT1": "s", "CUNIT2": "Angstrom", "CUNIT3": "deg", "CUNIT4": "deg",
"CDELT1": 600, "CDELT2": 0.2, "CDELT3": 0.5, "CDELT4": 0.4,
"CRPIX1": 0, "CRPIX2": 0, "CRPIX3": 2, "CRPIX4": 2,
"CRVAL1": 0, "CRVAL2": 10, "CRVAL3": 0.5, "CRVAL4": 1,
"CNAME1": "time", "CNAME2": "wavelength", "CNAME3": "HPC lat", "CNAME4": "HPC lon",
"NAXIS1": 5, "NAXIS2": 9, "NAXIS3": 4, "NAXIS4": 4,
"DATEREF": time_ref.fits}
orig_wcs = WCS(header)
# Slice WCS
wcs1 = SlicedLowLevelWCS(orig_wcs, (0, 0, slice(None), slice(1, None))) # numpy order
# Resample WCS
wcs2 = ResampledLowLevelWCS(wcs1, [2, 3], offset=[0.5, 1]) # WCS order
# Slice WCS again
wcs3 = SlicedLowLevelWCS(wcs2, (slice(0, 2), slice(1, 2))) # numpy order
# Reconstruct fitswcs
output_wcs, dropped_data_dimensions = unwrap_wcs_to_fitswcs(wcs3)
# Assert output_wcs is correct
assert_array_equal(dropped_data_dimensions, np.array([True, True, False, False]))
assert isinstance(output_wcs, WCS)
assert output_wcs._naxis == [1, 2, 1, 1]
assert list(output_wcs.wcs.ctype) == ['TIME', 'WAVE', 'HPLT-TAN', 'HPLN-TAN']
world_values = output_wcs.array_index_to_world_values([0], [0], [0, 1], [0])
assert_array_almost_equal(world_values[0][0], np.array([2700]))
assert_array_almost_equal(world_values[1], np.array([1.04e-09, 1.10e-09]))
assert_array_almost_equal(world_values[2][0], np.array([1.26915033e-05]))
assert_array_almost_equal(world_values[3][0], np.array([0.60002173]))
187 changes: 187 additions & 0 deletions ndcube/wcs/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from numbers import Integral

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper

from ndcube.wcs.wrappers import ResampledLowLevelWCS

__all__ = ["unwrap_wcs_to_fitswcs"]


def unwrap_wcs_to_fitswcs(wcs):
"""
Create FITS-WCS equivalent to (nested) WCS wrapper object.
Underlying WCS must be FITS-WCS.
No axes are dropped from original FITS-WCS, even if sliced by an integer.
Instead, integer-sliced axes is sliced to length-1 and marked True in the
``dropped_data_axes`` output.
Currently supported wrapper classes include `astropy.wcs.wcsapi.SlicedLowLevelWCS`
and `ndcube.wcs.wrappers.ResampledLowLevelWCS`.
Parameters
----------
wcs: `~astropy.wcs.wcsapi.BaseWCSWrapper`
The WCS Wrapper object.
Base level WCS implementation must be FITS-WCS.
Returns
-------
fitswcs: `astropy.wcs.WCS`
The equivalent FITS-WCS object.
dropped_data_axes: 1-D `numpy.ndarray`
Denotes which axes must have been dropped from the data array by slicing wrappers.
Axes are in array/numpy order, reversed compared to WCS.
"""
# If wcs is already a FITS-WCS, return it.
low_level_wrapper = wcs.low_level_wcs if hasattr(wcs, "low_level_wcs") else wcs
if isinstance(low_level_wrapper, WCS):
return low_level_wrapper, np.zeros(low_level_wrapper.naxis, dtype=bool)
# Determine chain of wrappers down to the FITS-WCS.
wrapper_chain = []
while isinstance(low_level_wrapper, BaseWCSWrapper):
wrapper_chain.append(low_level_wrapper)
low_level_wrapper = low_level_wrapper._wcs
if hasattr(low_level_wrapper, "low_level_wcs"):
low_level_wrapper = low_level_wrapper.low_level_wcs
if not isinstance(low_level_wrapper, WCS):
raise TypeError(f"Base-level WCS must be type {type(WCS)}. Found: {type(low_level_wcs)}")
fitswcs = low_level_wrapper
dropped_data_axes = np.zeros(fitswcs.naxis, dtype=bool)
# Unwrap each wrapper in reverse order and edit fitswcs.
for low_level_wrapper in wrapper_chain[::-1]:
if isinstance(low_level_wrapper, SlicedLowLevelWCS):
slice_items = np.array([slice(None)] * fitswcs.naxis)
slice_items[dropped_data_axes == False] = low_level_wrapper._slices_array # numpy order
fitswcs, dda = _slice_fitswcs(fitswcs, slice_items, numpy_order=True)
dropped_data_axes[dda] = True
elif isinstance(low_level_wrapper, ResampledLowLevelWCS):
factor = np.ones(fitswcs.naxis)
offset = np.zeros(fitswcs.naxis)
kept_wcs_axes = dropped_data_axes[::-1] == False # WCS-order
factor[kept_wcs_axes] = low_level_wrapper._factor
offset[kept_wcs_axes] = low_level_wrapper._offset
fitswcs = _resample_fitswcs(fitswcs, factor, offset)
else:
raise TypeError("Unrecognized/unsupported WCS Wrapper type: {type(low_level_wrapper)}")
return fitswcs, dropped_data_axes


def _slice_fitswcs(fitswcs, slice_items, numpy_order=True, shape=None):
"""
Slice a FITS-WCS.
If an `int` is given in ``slice_items``, the corresponding axis is not dropped.
But the new 0th pixel will correspond the index given by the `int` in the
original WCS.
Parameters
----------
fitswcs: `astropy.wcs.WCS`
The FITS-WCS object to be sliced.
slice_items: iterable of `slice` objects or `int`
The slices to by applied to each axis. If an `int` is provided, the axis
is sliced to length-1, but not dropped. However, its corresponding entry
in the ``dropped_data_axes`` output is marked True.
numpy_order: `bool`
If True, slices in ``slice_items`` are in array/numpy order, which is
reversed compared to the WCS order.
shape: sequence of `int`, optional
The length of each axis. Only used if negative indices are supplied
in ``slice_items``. If not supplied, set to ``fitswcs._naxis``.
Order defined by numpy_order kwarg.
Returns
-------
sliced_wcs: `astropy.wcs.WCS`
The sliced FITS-WCS.
dropped_data_axes: 1-D `numpy.ndarray`
Denotes which axes must have been dropped from the data array by slicing wrappers.
Order of axes (numpy or WCS) is dictated by ``numpy_order`` kwarg.
"""
def negative_index_error_msg(x): return (
f"Negative indexing not supported as {x}th axis length is 0 in "
"underlying FITS-WCS. Supply axes lengths via shape kwarg.")
naxis = fitswcs.naxis
dropped_data_axes = np.zeros(naxis, dtype=bool)
# Sanitize inputs
if shape is None:
shape = fitswcs._naxis
if numpy_order:
shape = shape[::-1]
else:
if len(shape) != naxis:
raise ValueError("shape kwarg must be same length as number of pixel axes "
f"in FITS-WCS, i.e. {naxis}")
if not all(isinstance(s, Integral) for s in shape):
raise TypeError("All elements of ``shape`` must be integers. "
f"shapes types = {[type(s) for s in shape]}")
slice_items = list(slice_items)
for i, (item, len_axis) in enumerate(zip(slice_items, shape)):
if isinstance(item, Integral):
# Mark axis corresponding to int item as dropped from data array.
dropped_data_axes[i] = True
# Convert negative indices to positive equivalent.
if item < 0:
if len_axis == 0:
raise ValueError(negative_index_error_msg(i))
item = len_axis + item
# Convert int item to slice so a FITS-WCS is returned after slicing.
slice_items[i] = slice(item, item + 1)
elif isinstance(item, slice):
# Convert negative indices inside slice item to positive equivalent.
start_neg = item.start is not None and item.start < 0
stop_neg = item.stop is not None and item.stop < 0
if start_neg or stop_neg:
if len_axis == 0:
raise ValueError(negative_index_error_msg(i))
start = len_axis + item.start if start_neg else item.start
stop = len_axis + item.stop if stop_neg else item.stop
slice_items[i] = slice(start, stop, item.step)
else:
raise TypeError("All slice_items must be a slice or an int. "
f"type(slice_items[{i}]) = {type(slice_items[i])}")
# Slice WCS
sliced_wcs = fitswcs.slice(slice_items, numpy_order=numpy_order)
return sliced_wcs, dropped_data_axes


def _resample_fitswcs(fitswcs, factor, offset=0):
"""
Resample the plate scale of a FITS-WCS by a given factor.
``factor`` and ``offset`` inputs are in pixel order.
Parameters
----------
fitswcs: `astropy.wcs.WCS`
The FITS-WCS object to be resampled.
factor: 1-D array-like or scalar
The factor by which the FITS-WCS is resampled.
Must be same length as number of axes in ``fitswcs``.
If scalar, the same factor is applied to all axes.
Factors must be given in WCS-order (opposite to data axes order).
offset: 1-D array-like or scalar
The location on the initial pixel grid which corresponds to zero on the
resampled pixel grid. If scalar, the same offset is applied to all axes.
Offsets must be given in WCS-order (opposite to data axes order).
Returns
-------
resampled_wcs: `astropy.wcs.WCS`
The resampled FITS-WCS.
"""
# Sanitize inputs.
factor = np.asarray(factor)
if len(factor) != fitswcs.naxis:
raise ValueError(f"Length of factor must equal number of dimensions {fitswcs.naxis}.")
offset = np.asarray(offset)
if len(offset) != fitswcs.naxis:
raise ValueError(f"Length of offset must equal number of dimensions {fitswcs.naxis}.")
# Scale plate scale and shift by offset.
fitswcs.wcs.cdelt *= factor
fitswcs.wcs.crpix = (fitswcs.wcs.crpix + offset) / factor
fitswcs._naxis = list(np.round(np.array(fitswcs._naxis) / factor).astype(int))
return fitswcs

0 comments on commit 3639e5c

Please sign in to comment.