diff --git a/specutils/io/default_loaders/tabular_fits.py b/specutils/io/default_loaders/tabular_fits.py index 9dc7ef73d..ad3c6b8aa 100644 --- a/specutils/io/default_loaders/tabular_fits.py +++ b/specutils/io/default_loaders/tabular_fits.py @@ -163,6 +163,19 @@ def tabular_fits_writer(spectrum, file_name, hdu=1, update_header=False, **kwarg raise ValueError("Could not convert uncertainty to StdDevUncertainty due" " to divide-by-zero error.") + # Add mask column if present + if spectrum.mask is not None: + # work around https://github.com/astropy/astropy/issues/11963 + # where int8 columns are written as bool, loosing information. + # upcast to int16 instead. + if spectrum.mask.dtype == np.int8: + mask = spectrum.mask.astype(np.int16) + else: + mask = spectrum.mask + + columns.append(mask) + colnames.append('mask') + # For > 1D data transpose from row-major format for c in range(1, len(columns)): if columns[c].ndim > 1: diff --git a/specutils/io/parsing_utils.py b/specutils/io/parsing_utils.py index b2d4437ca..1485bb0d8 100644 --- a/specutils/io/parsing_utils.py +++ b/specutils/io/parsing_utils.py @@ -282,11 +282,20 @@ def _find_spectral_column(table, columns_to_search, spectral_axis): else: err = None + # Check for mask + if 'mask' in table.colnames: + mask = table['mask'] + if mask.ndim > 1: + mask = mask.T + else: + mask = None + # Create the Spectrum1D object and return it if wcs is not None or spectral_axis_column is not None and flux_column is not None: # For > 1D spectral axis transpose to row-major format and return SpectrumCollection spectrum = Spectrum1D(flux=flux, spectral_axis=spectral_axis, - uncertainty=err, meta={'header': table.meta}, wcs=wcs) + uncertainty=err, meta={'header': table.meta}, wcs=wcs, + mask=mask) return spectrum diff --git a/specutils/tests/test_loaders.py b/specutils/tests/test_loaders.py index a546466a4..87bc7f17f 100644 --- a/specutils/tests/test_loaders.py +++ b/specutils/tests/test_loaders.py @@ -594,7 +594,70 @@ def test_tabular_fits_multid(tmp_path, ndim, spectral_axis): spectrum.uncertainty.quantity) -def test_tabular_fits_header(tmp_path): +@pytest.mark.parametrize("mask_type", [bool, np.uint8, np.int8, np.uint16, np.int16, '>i2']) +def test_tabular_fits_mask(tmp_path, mask_type): + # test mask I/O with tabular fits format + wave = np.arange(3600, 3700) * u.AA + nwave = len(wave) + + # 1D Case + flux = np.random.uniform(0,1,size=nwave) * u.Jy + mask = np.zeros(flux.shape, dtype=mask_type) + mask[0] = 1 + + sp1 = Spectrum1D(spectral_axis=wave, flux=flux, mask=mask) + assert sp1.mask.dtype == mask.dtype + + tmpfile = str(tmp_path / '_mask_tst.fits') + sp1.write(tmpfile, format='tabular-fits', overwrite=True) + + sp2 = Spectrum1D.read(tmpfile) + assert np.all(sp1.spectral_axis == sp2.spectral_axis) + assert np.all(sp1.flux == sp2.flux) + assert sp2.mask is not None + assert np.all(sp1.mask == sp2.mask) + + # int16 is returned as FITS-native '>i2' + if mask_type == np.int16: + assert sp1.mask.dtype.kind == sp2.mask.dtype.kind + assert sp1.mask.dtype.itemsize == sp2.mask.dtype.itemsize + elif mask_type == np.int8: + # due to https://github.com/astropy/astropy/issues/11963, + # int8 is upcast to int16 which is returned as >i2... + assert sp2.mask.dtype == np.dtype('>i2') + else: + assert sp1.mask.dtype == sp2.mask.dtype + + # 2D Case + nspec = 3 + flux = np.random.uniform(0,1,size=(nspec,nwave)) * u.Jy + mask = np.zeros(flux.shape, dtype=mask_type) + mask[0,0] = 1 + + sp1 = Spectrum1D(spectral_axis=wave, flux=flux, mask=mask) + + tmpfile = str(tmp_path / '_mask_tst.fits') + sp1.write(tmpfile, format='tabular-fits', overwrite=True) + + sp2 = Spectrum1D.read(tmpfile) + assert np.all(sp1.spectral_axis == sp2.spectral_axis) + assert np.all(sp1.flux == sp2.flux) + assert sp2.mask is not None + assert np.all(sp1.mask == sp2.mask) + + # int16 is returned as FITS-native '>i2' + if mask_type == np.int16: + assert sp1.mask.dtype.kind == sp2.mask.dtype.kind + assert sp1.mask.dtype.itemsize == sp2.mask.dtype.itemsize + elif mask_type == np.int8: + # due to https://github.com/astropy/astropy/issues/11963, + # int8 is upcast to int16 which is returned as >i2... + assert sp2.mask.dtype == np.dtype('>i2') + else: + assert sp1.mask.dtype == sp2.mask.dtype + + +def test_tabular_fits_maskheader(tmp_path): # Create a small data set + header with reserved FITS keywords disp = np.linspace(1, 1.2, 21) * u.AA flux = np.random.normal(0., 1.0e-14, disp.shape[0]) * u.Jy