Skip to content

Commit

Permalink
cross-tile spectra stack header cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephen Bailey authored and Stephen Bailey committed Aug 23, 2024
1 parent 75287e4 commit 411bfeb
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 31 deletions.
106 changes: 75 additions & 31 deletions py/desispec/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numbers

import numpy as np
import astropy.table
from astropy.table import Table
from astropy.units import Unit

Expand Down Expand Up @@ -839,6 +840,72 @@ def from_specutils(cls, spectra):
meta=meta, extra=extra, single=single, scores=scores,
scores_comments=scores_comments, extra_catalog=extra_catalog)

def _is_multitile(headers):
"""
If headers contain more than one TILEID, return True, otherwise False
Args:
headers: list of dict-like objects
Returns: True if more than one TILEID present, otherwise False
Note: if none of the headers have TILEID, also return False
"""
tileids = list()
for hdr in headers:
if 'TILEID' in hdr:
tileids.append(hdr['TILEID'])

if len(tileids)>0 and len(np.unique(tileids))>1:
return True
else:
return False

def _remove_tile_keywords(headers):
"""
Remove tile-specific keywords from headers
Args:
headers: list of dict-like objects
Note: modified input headers in-place
"""
tile_keywords = ['TILEID', 'TILERA', 'TILEDEC', 'FIELDROT', 'FA_RUN', 'REQRA', 'REQDEC',
'PMTIME', 'RUNDATE', 'FAARGS', 'MTLTIME', 'EBVFAC']

for hdr in headers:
for key in tile_keywords:
if key in hdr:
del hdr[key]

def _stack_fibermaps(fibermaps):
"""
Stack fibermaps while handling meta keyword merging and numpy vs. Table
Args:
fibermaps: astropy Table or numpy structured arrays
Returns: stacked fibermap
Note: all fibermaps must be the same type with same columns
"""
if isinstance(fibermaps[0], np.ndarray):
#- note named arrays need hstack not vstack
fibermap = np.hstack(fibermaps)
else:
if isinstance(fibermaps[0], astropy.table.Table):
#- copy tables so that we can update .meta, but ok to not copy underlying data
fibermaps = [fm.copy(copy_data=False) for fm in fibermaps]
headers = [fm.meta for fm in fibermaps]
if _is_multitile(headers):
_remove_tile_keywords(headers)

fibermap = astropy.table.vstack(fibermaps)
else:
raise ValueError("Can't stack fibermaps of type {}".format(
type(fibermaps[0])))

return fibermap

def stack(speclist):
"""
Expand Down Expand Up @@ -880,44 +947,17 @@ def stack(speclist):
rdat = None

if speclist[0].fibermap is not None:
if isinstance(speclist[0].fibermap, np.ndarray):
#- note named arrays need hstack not vstack
fibermap = np.hstack([sp.fibermap for sp in speclist])
else:
import astropy.table
if isinstance(speclist[0].fibermap, astropy.table.Table):
fibermap = astropy.table.vstack([sp.fibermap for sp in speclist])
else:
raise ValueError("Can't stack fibermaps of type {}".format(
type(speclist[0].fibermap)))
fibermap = _stack_fibermaps([sp.fibermap for sp in speclist])
else:
fibermap = None

if speclist[0].exp_fibermap is not None:
if isinstance(speclist[0].exp_fibermap, np.ndarray):
#- note named arrays need hstack not vstack
exp_fibermap = np.hstack([sp.exp_fibermap for sp in speclist])
else:
import astropy.table
if isinstance(speclist[0].exp_fibermap, astropy.table.Table):
exp_fibermap = astropy.table.vstack([sp.exp_fibermap for sp in speclist])
else:
raise ValueError("Can't stack exp_fibermaps of type {}".format(
type(speclist[0].exp_fibermap)))
exp_fibermap = _stack_fibermaps([sp.exp_fibermap for sp in speclist])
else:
exp_fibermap = None

if speclist[0].extra_catalog is not None:
if isinstance(speclist[0].extra_catalog, np.ndarray):
#- note named arrays need hstack not vstack
extra_catalog = np.hstack([sp.extra_catalog for sp in speclist])
else:
import astropy.table
if isinstance(speclist[0].extra_catalog, astropy.table.Table):
extra_catalog = astropy.table.vstack([sp.extra_catalog for sp in speclist])
else:
raise ValueError("Can't stack extra_catalogs of type {}".format(
type(speclist[0].extra_catalog)))
extra_catalog = _stack_fibermaps([sp.extra_catalog for sp in speclist])
else:
extra_catalog = None

Expand All @@ -937,10 +977,14 @@ def stack(speclist):
else:
scores = None

headers = [sp.meta.copy() for sp in speclist]
if _is_multitile(headers):
_remove_tile_keywords(headers)

sp = Spectra(bands, wave, flux, ivar,
mask=mask, resolution_data=rdat,
fibermap=fibermap, exp_fibermap=exp_fibermap,
meta=speclist[0].meta, extra=extra, scores=scores,
meta=headers[0], extra=extra, scores=scores,
extra_catalog=extra_catalog,
)
return sp
42 changes: 42 additions & 0 deletions py/desispec/test/test_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,48 @@ def test_stack(self):
sp3 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar)
spx = stack([sp1, sp2, sp3])

#- Cross-tile stacking of same TILEID keeps tile-specific keywords
sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, fibermap=self.fmap1.copy())
sp2 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, fibermap=self.fmap2.copy())
sp1.fibermap.meta['TILEID'] = 1
sp1.fibermap.meta['TILERA'] = 10
sp1.meta['TILEID'] = 1
sp1.meta['TILERA'] = 10
sp2.fibermap.meta['TILEID'] = 1
sp2.fibermap.meta['TILERA'] = 10
sp2.meta['TILEID'] = 1
sp2.meta['TILERA'] = 10

spx = stack([sp1, sp2])
self.assertIn('TILEID', spx.meta)
self.assertIn('TILERA', spx.meta)
self.assertIn('TILEID', spx.fibermap.meta)
self.assertIn('TILERA', spx.fibermap.meta)
self.assertEqual(spx.meta['TILEID'], sp1.meta['TILEID'])
self.assertEqual(spx.meta['TILERA'], sp1.meta['TILERA'])

#- but cross-tile stacking of different tiles drops tile-specific keywords
#- without modifying original inputs
sp2.fibermap.meta['TILEID'] = 2
sp2.fibermap.meta['TILERA'] = 20
sp2.meta['TILEID'] = 2
sp2.meta['TILERA'] = 20

spx = stack([sp1, sp2])
self.assertNotIn('TILEID', spx.meta)
self.assertNotIn('TILERA', spx.meta)
self.assertNotIn('TILEID', spx.fibermap.meta)
self.assertNotIn('TILERA', spx.fibermap.meta)

self.assertIn('TILEID', sp1.meta)
self.assertIn('TILEID', sp1.fibermap.meta)
self.assertIn('TILEID', sp2.meta)
self.assertIn('TILEID', sp2.fibermap.meta)
self.assertEqual(sp1.meta['TILEID'], 1)
self.assertEqual(sp1.fibermap.meta['TILEID'], 1)
self.assertEqual(sp2.meta['TILEID'], 2)
self.assertEqual(sp2.fibermap.meta['TILEID'], 2)

def test_slice(self):
"""Test desispec.spectra.__getitem__"""
sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
Expand Down

0 comments on commit 411bfeb

Please sign in to comment.