From 20deaacf635d77475611aa97361bccaa4af72d44 Mon Sep 17 00:00:00 2001
From: Ricky O'Steen <39831871+rosteen@users.noreply.github.com>
Date: Mon, 28 Aug 2023 10:33:20 -0400
Subject: [PATCH] Improved GWCS handling in Spectrum1D (#1074)

* Allow spectral axis to be anywhere, instead of forcing it to be last (#1033)

* Starting to work on flexible spectral axis location

Debugging initial spectrum creation

Set private attribute here

Working on debugging failing tests

More things are temporarily broken, but I don't want to lose this work so I'm committing here

Set spectral axis index to 0 if flux is None

Working through test failures

Fix codestyle

Allow passing spectral_axis_index to wcs_fits loader

Require specification of spectral_axis_index if WCS is 1D and flux is multi-D

Decrement spectral_axis_index when slicing with integers

Propagate spectral_axis_index through resampling

Fix last test to account for spectral axis staying first

Fix codestyle

Specify spectral_axis_index in SDSS plate loader

Greatly simply extract_bounding_spectral_region

Account for variable spectral axis location in moment calculation, fix doc example

Working on SpectrumCollection moment handling...not sure this is the way

Need to add one to the axis index here

Update narrative docs to reflect updates

* Add back in the option to move the spectral axis to last, for back-compatibility

Work around pixel unit slicing failure

Change order on crop example

Fix spectral slice handling in tuple input case (e.g. crop)

Update output of crop example

* Apply suggestions from code review

Co-authored-by: Adam Ginsburg <keflavich@gmail.com>

Apply suggestion from code review

Add helpful comment

* Address review comment about move_spectral_axis, more docs

* Add suggested line to docstring

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Add convenience method

Make this a docstring

* Add v2.0.0 changelog section

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Prepare changelog for 1.10.0 release

* Fix Changelog

* Fixed issues with ndcube 2.1 docs

* Fix incorrect fluxes and uncertainties returned by FluxConservingResampler, increase computation speed  (#1060)

* new implementation of flux conserving resample

* removed unused method

* handle multi dimensional flux inputs

* .

* Update CHANGES.rst

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* omit removing units

* added test to compare output to output from running SpectRes

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>

* Update changelog for 1.11.0 release

* Changelog back to unreleased

* Working on retaining full GWCS information in Spectrum1D rather than just spectral coords

* Handle getting the spectral axis out of a GWCS

Add changelog heading

Remove debugging prints

Fix changelog

Fix codestyle

* Add changelog entry

* Delete the commented-out old wavelength parsing code

* More accurate changelog

---------

Co-authored-by: Erik Tollerud <erik.tollerud@gmail.com>
Co-authored-by: Nabil Freij <nabil.freij@gmail.com>
Co-authored-by: Clare Shanahan <cshanahan@stsci.edu>
---
 CHANGES.rst                                   |  5 ++
 specutils/io/default_loaders/jwst_reader.py   | 44 ++------------
 .../default_loaders/tests/test_jwst_reader.py |  2 +-
 specutils/spectra/spectrum1d.py               | 59 ++++++++++++++-----
 4 files changed, 54 insertions(+), 56 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index d66979614..e05078dfb 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -7,6 +7,11 @@ New Features
 - Spectral axis can now be any axis, rather than being forced to be last. See docs
   for more details. [#1033]
 
+- Spectrum1D now properly handles GWCS input for wcs attribute. [#1074]
+
+- JWST reader no longer transposes the input data cube for 3D data and retains
+  full GWCS information (including spatial). [#1074]
+
 Other Changes and Additions
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/specutils/io/default_loaders/jwst_reader.py b/specutils/io/default_loaders/jwst_reader.py
index 423449c05..05e5b0e4a 100644
--- a/specutils/io/default_loaders/jwst_reader.py
+++ b/specutils/io/default_loaders/jwst_reader.py
@@ -8,8 +8,6 @@
 from astropy.table import Table
 from astropy.io import fits
 from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance
-from astropy.time import Time
-from astropy.wcs import WCS
 from gwcs.wcstools import grid_from_bounding_box
 
 from ...spectra import Spectrum1D, SpectrumList
@@ -579,38 +577,9 @@ def _jwst_s3d_loader(filename, **kwargs):
             except (ValueError, KeyError):
                 flux_unit = None
 
-            # The spectral axis is first.  We need it last
-            flux_array = hdu.data.T
+            flux_array = hdu.data
             flux = Quantity(flux_array, unit=flux_unit)
 
-            # Get the wavelength array from the GWCS object which returns a
-            # tuple of (RA, Dec, lambda).
-            # Since the spatial and spectral axes are orthogonal in s3d data,
-            # it is much faster to compute a slice down the spectral axis.
-            grid = grid_from_bounding_box(wcs.bounding_box)[:, :, 0, 0]
-            _, _, wavelength_array = wcs(*grid)
-            _, _, wavelength_unit = wcs.output_frame.unit
-
-            wavelength = Quantity(wavelength_array, unit=wavelength_unit)
-
-            # The GWCS is currently broken for some IFUs, here we work around that
-            wcs = None
-            if wavelength.shape[0] != flux.shape[-1]:
-                # Need MJD-OBS for this workaround
-                if 'MJD-OBS' not in hdu.header:
-                    for key in ('MJD-BEG', 'DATE-OBS'):  # Possible alternatives
-                        if key in hdu.header:
-                            if key.startswith('MJD'):
-                                hdu.header['MJD-OBS'] = hdu.header[key]
-                                break
-                            else:
-                                t = Time(hdu.header[key])
-                                hdu.header['MJD-OBS'] = t.mjd
-                                break
-                wcs = WCS(hdu.header)
-                # Swap to match the flux transpose
-                wcs = wcs.swapaxes(-1, 0)
-
             # Merge primary and slit headers and dump into meta
             slit_header = hdu.header
             header = primary_header.copy()
@@ -621,7 +590,7 @@ def _jwst_s3d_loader(filename, **kwargs):
             ext_name = primary_header.get("ERREXT", "ERR")
             err_type = hdulist[ext_name].header.get("ERRTYPE", 'ERR')
             err_unit = hdulist[ext_name].header.get("BUNIT", None)
-            err_array = hdulist[ext_name].data.T
+            err_array = hdulist[ext_name].data
 
             # ERRTYPE can be one of "ERR", "IERR", "VAR", "IVAR"
             # but mostly ERR for JWST cubes
@@ -639,13 +608,10 @@ def _jwst_s3d_loader(filename, **kwargs):
 
             # get mask information
             mask_name = primary_header.get("MASKEXT", "DQ")
-            mask = hdulist[mask_name].data.T
+            mask = hdulist[mask_name].data
+
+            spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask, spectral_axis_index=0)
 
-            if wcs is not None:
-                spec = Spectrum1D(flux=flux, wcs=wcs, meta=meta, uncertainty=err, mask=mask)
-            else:
-                spec = Spectrum1D(flux=flux, spectral_axis=wavelength, meta=meta,
-                                  uncertainty=err, mask=mask)
             spectra.append(spec)
 
     return SpectrumList(spectra)
diff --git a/specutils/io/default_loaders/tests/test_jwst_reader.py b/specutils/io/default_loaders/tests/test_jwst_reader.py
index 80b6d7799..50523b2f7 100644
--- a/specutils/io/default_loaders/tests/test_jwst_reader.py
+++ b/specutils/io/default_loaders/tests/test_jwst_reader.py
@@ -434,7 +434,7 @@ def test_jwst_s3d_single(tmp_path, cube):
 
     data = Spectrum1D.read(tmpfile, format='JWST s3d')
     assert type(data) is Spectrum1D
-    assert data.shape == (10, 10, 30)
+    assert data.shape == (30, 10, 10)
     assert data.uncertainty is not None
     assert data.mask is not None
     assert data.uncertainty.unit == 'MJy'
diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py
index d57b5a34b..3bf4d4b07 100644
--- a/specutils/spectra/spectrum1d.py
+++ b/specutils/spectra/spectrum1d.py
@@ -3,9 +3,11 @@
 
 import numpy as np
 from astropy import units as u
+from astropy.coordinates import SpectralCoord
 from astropy.utils.decorators import lazyproperty
 from astropy.utils.decorators import deprecated
 from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin
+from gwcs.wcs import WCS as GWCS
 
 from .spectral_axis import SpectralAxis
 from .spectrum_mixin import OneDSpectrumMixin
@@ -228,24 +230,34 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
                     f"of the corresponding flux axis ({flux.shape[self.spectral_axis_index]})")
 
         # If a WCS is provided, determine which axis is the spectral axis
-        if wcs is not None and hasattr(wcs, "naxis"):
-            if wcs.naxis > 1:
+        if wcs is not None:
+            naxis = None
+            if hasattr(wcs, "naxis"):
+                naxis = wcs.naxis
+            # GWCS doesn't have naxis
+            elif hasattr(wcs, "world_n_dim"):
+                naxis = wcs.world_n_dim
+
+            if naxis is not None and naxis > 1:
                 temp_axes = []
                 phys_axes = wcs.world_axis_physical_types
-                for i in range(len(phys_axes)):
-                    if phys_axes[i] is None:
-                        continue
-                    if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
-                        temp_axes.append(i)
-                if len(temp_axes) != 1:
-                    raise ValueError("Input WCS must have exactly one axis with "
-                                     "spectral units, found {}".format(len(temp_axes)))
-                else:
-                    # Due to FITS conventions, the WCS axes are listed in opposite
-                    # order compared to the data array.
-                    self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1
+                if self._spectral_axis_index is None:
+                    for i in range(len(phys_axes)):
+                        if phys_axes[i] is None:
+                            continue
+                        if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect":
+                            temp_axes.append(i)
+                    if len(temp_axes) != 1:
+                        raise ValueError("Input WCS must have exactly one axis with "
+                                        "spectral units, found {}".format(len(temp_axes)))
+                    else:
+                        # Due to FITS conventions, the WCS axes are listed in opposite
+                        # order compared to the data array.
+                        self._spectral_axis_index = len(flux.shape)-temp_axes[0]-1
 
                 if move_spectral_axis is not None:
+                    if isinstance(wcs, GWCS):
+                        raise ValueError("move_spectral_axis cannot be used with GWCS")
                     if isinstance(move_spectral_axis, str):
                         if move_spectral_axis.lower() == 'first':
                             move_to_index = 0
@@ -353,8 +365,23 @@ def __init__(self, flux=None, spectral_axis=None, spectral_axis_index=None,
                     spec_axis = self.wcs.spectral.pixel_to_world(
                                     np.arange(self.flux.shape[self.spectral_axis_index]))
             else:
-                spec_axis = self.wcs.pixel_to_world(
-                                np.arange(self.flux.shape[self.spectral_axis_index]))
+                # We now keep the entire GWCS, including spatial information, so we need to include
+                # all axes in the pixel_to_world call. Note that this assumes/requires that the
+                # dispersion is the same at all spatial locations.
+                wcs_args = []
+                for i in range(len(self.flux.shape)):
+                    wcs_args.append(np.zeros(self.flux.shape[self.spectral_axis_index]))
+                # Replace with arange for the spectral axis
+                wcs_args[self.spectral_axis_index] = np.arange(self.flux.shape[self.spectral_axis_index])
+                wcs_args.reverse()
+                temp_coords = self.wcs.pixel_to_world(*wcs_args)
+                # If there are spatial axes, temp_coords will have a SkyCoord and a SpectralCoord
+                if isinstance(temp_coords, list):
+                    for coords in temp_coords:
+                        if isinstance(coords, SpectralCoord):
+                            spec_axis = coords
+                else:
+                    spec_axis = temp_coords
 
             try:
                 if spec_axis.unit.is_equivalent(u.one):