Skip to content

Commit

Permalink
add tabular-fits support for Spectrum1D mask
Browse files Browse the repository at this point in the history
  • Loading branch information
sbailey committed Nov 15, 2023
1 parent ba0ab68 commit 433608b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
13 changes: 13 additions & 0 deletions specutils/io/default_loaders/tabular_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ def tabular_fits_writer(spectrum, file_name, hdu=1, update_header=False, **kwarg
raise ValueError("Could not convert uncertainty to StdDevUncertainty due"
" to divide-by-zero error.")

# Add mask column if present
if spectrum.mask is not None:
# work around https://github.com/astropy/astropy/issues/11963
# where int8 columns are written as bool, loosing information.
# upcast to int16 instead.
if spectrum.mask.dtype == np.int8:
mask = spectrum.mask.astype(np.int16)
else:
mask = spectrum.mask

columns.append(mask)
colnames.append('mask')

# For > 1D data transpose from row-major format
for c in range(1, len(columns)):
if columns[c].ndim > 1:
Expand Down
11 changes: 10 additions & 1 deletion specutils/io/parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,20 @@ def _find_spectral_column(table, columns_to_search, spectral_axis):
else:
err = None

# Check for mask
if 'mask' in table.colnames:
mask = table['mask']
if mask.ndim > 1:
mask = mask.T
else:
mask = None

# Create the Spectrum1D object and return it
if wcs is not None or spectral_axis_column is not None and flux_column is not None:
# For > 1D spectral axis transpose to row-major format and return SpectrumCollection
spectrum = Spectrum1D(flux=flux, spectral_axis=spectral_axis,
uncertainty=err, meta={'header': table.meta}, wcs=wcs)
uncertainty=err, meta={'header': table.meta}, wcs=wcs,
mask=mask)

return spectrum

Expand Down
63 changes: 62 additions & 1 deletion specutils/tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,69 @@ def test_tabular_fits_multid(tmp_path, ndim, spectral_axis):
assert quantity_allclose(spec.uncertainty.quantity,
spectrum.uncertainty.quantity)

@pytest.mark.parametrize("mask_type", [bool, np.uint8, np.int8, np.uint16, np.int16, '>i2'])
def test_tabular_fits_mask(tmp_path, mask_type):
# test mask I/O with tabular fits format
wave = np.arange(3600, 3700) * u.AA
nwave = len(wave)

def test_tabular_fits_header(tmp_path):
#- 1D Case
flux = np.random.uniform(0,1,size=nwave) * u.Jy
mask = np.zeros(flux.shape, dtype=mask_type)
mask[0] = 1

sp1 = Spectrum1D(spectral_axis=wave, flux=flux, mask=mask)
assert sp1.mask.dtype == mask.dtype

tmpfile = str(tmp_path / '_mask_tst.fits')
sp1.write(tmpfile, format='tabular-fits', overwrite=True)

sp2 = Spectrum1D.read(tmpfile)
assert np.all(sp1.spectral_axis == sp2.spectral_axis)
assert np.all(sp1.flux == sp2.flux)
assert sp2.mask is not None
assert np.all(sp1.mask == sp2.mask)

# int16 is returned as FITS-native '>i2'
if mask_type == np.int16:
assert sp1.mask.dtype.kind == sp2.mask.dtype.kind
assert sp1.mask.dtype.itemsize == sp2.mask.dtype.itemsize
elif mask_type == np.int8:
# due to https://github.com/astropy/astropy/issues/11963,
# int8 is upcast to int16 which is returned as >i2...
assert sp2.mask.dtype == np.dtype('>i2')
else:
assert sp1.mask.dtype == sp2.mask.dtype

#- 2D Case
nspec = 3
flux = np.random.uniform(0,1,size=(nspec,nwave)) * u.Jy
mask = np.zeros(flux.shape, dtype=mask_type)
mask[0,0] = 1

sp1 = Spectrum1D(spectral_axis=wave, flux=flux, mask=mask)

tmpfile = str(tmp_path / '_mask_tst.fits')
sp1.write(tmpfile, format='tabular-fits', overwrite=True)

sp2 = Spectrum1D.read(tmpfile)
assert np.all(sp1.spectral_axis == sp2.spectral_axis)
assert np.all(sp1.flux == sp2.flux)
assert sp2.mask is not None
assert np.all(sp1.mask == sp2.mask)

# int16 is returned as FITS-native '>i2'
if mask_type == np.int16:
assert sp1.mask.dtype.kind == sp2.mask.dtype.kind
assert sp1.mask.dtype.itemsize == sp2.mask.dtype.itemsize
elif mask_type == np.int8:
# due to https://github.com/astropy/astropy/issues/11963,
# int8 is upcast to int16 which is returned as >i2...
assert sp2.mask.dtype == np.dtype('>i2')
else:
assert sp1.mask.dtype == sp2.mask.dtype

def test_tabular_fits_maskheader(tmp_path):
# Create a small data set + header with reserved FITS keywords
disp = np.linspace(1, 1.2, 21) * u.AA
flux = np.random.normal(0., 1.0e-14, disp.shape[0]) * u.Jy
Expand Down

0 comments on commit 433608b

Please sign in to comment.