diff --git a/py/desispec/spectra.py b/py/desispec/spectra.py index 740472cdf..e75c8a795 100644 --- a/py/desispec/spectra.py +++ b/py/desispec/spectra.py @@ -18,6 +18,7 @@ import numbers import numpy as np +import astropy.table from astropy.table import Table from astropy.units import Unit @@ -839,6 +840,72 @@ def from_specutils(cls, spectra): meta=meta, extra=extra, single=single, scores=scores, scores_comments=scores_comments, extra_catalog=extra_catalog) +def _is_multitile(headers): + """ + If headers contain more than one TILEID, return True, otherwise False + + Args: + headers: list of dict-like objects + + Returns: True if more than one TILEID present, otherwise False + + Note: if none of the headers have TILEID, also return False + """ + tileids = list() + for hdr in headers: + if 'TILEID' in hdr: + tileids.append(hdr['TILEID']) + + if len(tileids)>0 and len(np.unique(tileids))>1: + return True + else: + return False + +def _remove_tile_keywords(headers): + """ + Remove tile-specific keywords from headers + + Args: + headers: list of dict-like objects + + Note: modified input headers in-place + """ + tile_keywords = ['TILEID', 'TILERA', 'TILEDEC', 'FIELDROT', 'FA_RUN', 'REQRA', 'REQDEC', + 'PMTIME', 'RUNDATE', 'FAARGS', 'MTLTIME', 'EBVFAC'] + + for hdr in headers: + for key in tile_keywords: + if key in hdr: + del hdr[key] + +def _stack_fibermaps(fibermaps): + """ + Stack fibermaps while handling meta keyword merging and numpy vs. Table + + Args: + fibermaps: astropy Table or numpy structured arrays + + Returns: stacked fibermap + + Note: all fibermaps must be the same type with same columns + """ + if isinstance(fibermaps[0], np.ndarray): + #- note named arrays need hstack not vstack + fibermap = np.hstack(fibermaps) + else: + if isinstance(fibermaps[0], astropy.table.Table): + #- copy tables so that we can update .meta, but ok to not copy underlying data + fibermaps = [fm.copy(copy_data=False) for fm in fibermaps] + headers = [fm.meta for fm in fibermaps] + if _is_multitile(headers): + _remove_tile_keywords(headers) + + fibermap = astropy.table.vstack(fibermaps) + else: + raise ValueError("Can't stack fibermaps of type {}".format( + type(fibermaps[0]))) + + return fibermap def stack(speclist): """ @@ -880,44 +947,17 @@ def stack(speclist): rdat = None if speclist[0].fibermap is not None: - if isinstance(speclist[0].fibermap, np.ndarray): - #- note named arrays need hstack not vstack - fibermap = np.hstack([sp.fibermap for sp in speclist]) - else: - import astropy.table - if isinstance(speclist[0].fibermap, astropy.table.Table): - fibermap = astropy.table.vstack([sp.fibermap for sp in speclist]) - else: - raise ValueError("Can't stack fibermaps of type {}".format( - type(speclist[0].fibermap))) + fibermap = _stack_fibermaps([sp.fibermap for sp in speclist]) else: fibermap = None if speclist[0].exp_fibermap is not None: - if isinstance(speclist[0].exp_fibermap, np.ndarray): - #- note named arrays need hstack not vstack - exp_fibermap = np.hstack([sp.exp_fibermap for sp in speclist]) - else: - import astropy.table - if isinstance(speclist[0].exp_fibermap, astropy.table.Table): - exp_fibermap = astropy.table.vstack([sp.exp_fibermap for sp in speclist]) - else: - raise ValueError("Can't stack exp_fibermaps of type {}".format( - type(speclist[0].exp_fibermap))) + exp_fibermap = _stack_fibermaps([sp.exp_fibermap for sp in speclist]) else: exp_fibermap = None if speclist[0].extra_catalog is not None: - if isinstance(speclist[0].extra_catalog, np.ndarray): - #- note named arrays need hstack not vstack - extra_catalog = np.hstack([sp.extra_catalog for sp in speclist]) - else: - import astropy.table - if isinstance(speclist[0].extra_catalog, astropy.table.Table): - extra_catalog = astropy.table.vstack([sp.extra_catalog for sp in speclist]) - else: - raise ValueError("Can't stack extra_catalogs of type {}".format( - type(speclist[0].extra_catalog))) + extra_catalog = _stack_fibermaps([sp.extra_catalog for sp in speclist]) else: extra_catalog = None @@ -937,10 +977,14 @@ def stack(speclist): else: scores = None + headers = [sp.meta.copy() for sp in speclist] + if _is_multitile(headers): + _remove_tile_keywords(headers) + sp = Spectra(bands, wave, flux, ivar, mask=mask, resolution_data=rdat, fibermap=fibermap, exp_fibermap=exp_fibermap, - meta=speclist[0].meta, extra=extra, scores=scores, + meta=headers[0], extra=extra, scores=scores, extra_catalog=extra_catalog, ) return sp diff --git a/py/desispec/test/test_spectra.py b/py/desispec/test/test_spectra.py index e5f8ed0fb..3d1d7a7da 100644 --- a/py/desispec/test/test_spectra.py +++ b/py/desispec/test/test_spectra.py @@ -478,6 +478,48 @@ def test_stack(self): sp3 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar) spx = stack([sp1, sp2, sp3]) + #- Cross-tile stacking of same TILEID keeps tile-specific keywords + sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, fibermap=self.fmap1.copy()) + sp2 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, fibermap=self.fmap2.copy()) + sp1.fibermap.meta['TILEID'] = 1 + sp1.fibermap.meta['TILERA'] = 10 + sp1.meta['TILEID'] = 1 + sp1.meta['TILERA'] = 10 + sp2.fibermap.meta['TILEID'] = 1 + sp2.fibermap.meta['TILERA'] = 10 + sp2.meta['TILEID'] = 1 + sp2.meta['TILERA'] = 10 + + spx = stack([sp1, sp2]) + self.assertIn('TILEID', spx.meta) + self.assertIn('TILERA', spx.meta) + self.assertIn('TILEID', spx.fibermap.meta) + self.assertIn('TILERA', spx.fibermap.meta) + self.assertEqual(spx.meta['TILEID'], sp1.meta['TILEID']) + self.assertEqual(spx.meta['TILERA'], sp1.meta['TILERA']) + + #- but cross-tile stacking of different tiles drops tile-specific keywords + #- without modifying original inputs + sp2.fibermap.meta['TILEID'] = 2 + sp2.fibermap.meta['TILERA'] = 20 + sp2.meta['TILEID'] = 2 + sp2.meta['TILERA'] = 20 + + spx = stack([sp1, sp2]) + self.assertNotIn('TILEID', spx.meta) + self.assertNotIn('TILERA', spx.meta) + self.assertNotIn('TILEID', spx.fibermap.meta) + self.assertNotIn('TILERA', spx.fibermap.meta) + + self.assertIn('TILEID', sp1.meta) + self.assertIn('TILEID', sp1.fibermap.meta) + self.assertIn('TILEID', sp2.meta) + self.assertIn('TILEID', sp2.fibermap.meta) + self.assertEqual(sp1.meta['TILEID'], 1) + self.assertEqual(sp1.fibermap.meta['TILEID'], 1) + self.assertEqual(sp2.meta['TILEID'], 2) + self.assertEqual(sp2.fibermap.meta['TILEID'], 2) + def test_slice(self): """Test desispec.spectra.__getitem__""" sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,