From ac7b53b73ce275cfea9c5e5c7b6cf30117cc93e5 Mon Sep 17 00:00:00 2001
From: Derek Homeier <dhomeie@gwdg.de>
Date: Tue, 26 Mar 2024 20:53:27 +0100
Subject: [PATCH] WIP: "fix" incompatible WCS header from IRAF-style multispec
 formats

---
 specutils/io/default_loaders/wcs_fits.py | 47 +++++++++++++++---------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/specutils/io/default_loaders/wcs_fits.py b/specutils/io/default_loaders/wcs_fits.py
index 9ac9d435c..5dc43a406 100644
--- a/specutils/io/default_loaders/wcs_fits.py
+++ b/specutils/io/default_loaders/wcs_fits.py
@@ -125,7 +125,6 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
             raise ValueError('No HDU with spectral data found.')
 
         header = hdulist[hdu].header
-        wcs = WCS(header)
 
         if 'BUNIT' in header:
             data = u.Quantity(hdulist[hdu].data, unit=header['BUNIT'])
@@ -188,23 +187,37 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
                 uunit = uunit**UNCERT_EXP[unc_type.lower()]
             uncertainty = UNCERT_REF[unc_type](u.Quantity(uncertainty, unit=uunit))
 
+    # Have to translate to WCS-recognised keywords for IRAF-Multispec format:
+    if header.get('CTYPE1', 'WAVE') == 'MULTISPE':
+        header['SYSTEM'] = 'MULTISPE'
+        if 'WAT1_001' in header:
+            # Try to extract from IRAF-style card or use Angstrom as default.
+            wat_dict = dict((rec.split('=') for rec in header['WAT1_001'].split()))
+            unit = wat_dict.get('units', 'Angstrom')
+            if hasattr(u, unit):
+                header['CUNIT1'] = unit
+            else:  # try with unit name stripped of excess plural 's'...
+                header['CUNIT1'] = unit.rstrip('s')
+            ctype_def = u.Unit(header['CUNIT1']).physical_type
+            ctype_def = 'WAVE' if 'length' in ctype_def else list(ctype_def)[0]
+            ctype1 = wat_dict.get('label', ctype_def)
+            header['CTYPE1'] = ctype1[:4].upper()
+            if verbose:
+                print(f"Extracted spectral axis '{header['CTYPE1']}' "
+                      f"with unit '{header['CUNIT1']}' from 'WAT1_001'")
+        elif header.get('CUNIT1', '') == '':
+            header['CUNIT1'] = 'Angstrom'
+            header['CTYPE1'] = 'WAVE'
+
+    wcs = WCS(header)
     if spectral_axis_unit is not None:
         wcs.wcs.cunit[0] = str(spectral_axis_unit)
-    elif wcs.wcs.cunit[0] == '' and 'WAT1_001' in header:
-        # Try to extract from IRAF-style card or use Angstrom as default.
-        wat_dict = dict((rec.split('=') for rec in header['WAT1_001'].split()))
-        unit = wat_dict.get('units', 'Angstrom')
-        if hasattr(u, unit):
-            wcs.wcs.cunit[0] = unit
-        else:  # try with unit name stripped of excess plural 's'...
-            wcs.wcs.cunit[0] = unit.rstrip('s')
-        if verbose:
-            print(f"Extracted spectral axis unit '{unit}' from 'WAT1_001'")
-    elif wcs.wcs.cunit[0] == '':
-        wcs.wcs.cunit[0] = 'Angstrom'
 
-    # Compatibility attribute for lookup_table (gwcs) WCS
+    # Compatibility attribute for lookup_table (gwcs) WCS; set physical_type for Spectrum1D.
     wcs.unit = tuple(wcs.wcs.cunit)
+    if verbose:
+        print(f"WCS spectral axis unit '{wcs.unit}'")
+        print(f"WCS physical axes '{wcs.world_axis_physical_types}'")
 
     meta = {'header': header}
 
@@ -535,12 +548,12 @@ def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None
         wat_dict = dict((rec.split('=') for rec in header['WAT1_001'].split()))
         unit = wat_dict.get('units', 'Angstrom')
         if hasattr(u, unit):
-            spectral_axis_unit = unit
+            spectral_axis_unit = u.Unit(unit)
         else:  # try with unit name stripped of excess plural 's'...
-            spectral_axis_unit = unit.rstrip('s')
+            spectral_axis_unit = u.Unit(unit.rstrip('s'))
         if verbose:
             print(f"Extracted spectral axis unit '{spectral_axis_unit}' from 'WAT1_001'")
-    spectral_axis *= u.Unit(spectral_axis_unit)
+    spectral_axis *= spectral_axis_unit
 
     return spectral_axis, data, dict(header=header)