Skip to content

Commit

Permalink
fix competing tsnr ensemble caches
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephen Bailey authored and Stephen Bailey committed Aug 22, 2024
1 parent 07b01cc commit 63e0ea4
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions py/desispec/tsnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,16 +405,12 @@ def write(self,filename) :
log.info('Successfully written to '+filename)

_tsnr_ensembles = None
def get_ensemble(dirpath=None, bands=["b","r","z"], smooth=0):
def get_ensemble(dirpath=None, smooth=0):
'''
Function that takes a frame object and a bitmask and
returns ivar (and optionally mask) array(s) that have fibers with
offending bits in fibermap['FIBERSTATUS'] set to
0 in ivar and optionally flips a bit in mask.
Read TSNR ensembles from $DESIMODEL/data/tsnr
Args:
dirpath: path to the dir. with ensemble dflux files. default is $DESIMODEL/data/tsnr
bands (list, optional): bands to expect, typically [BRZ] - case ignored.
smooth (int, optional): Further convolve the residual ensemble flux.
Returns:
Expand All @@ -423,17 +419,25 @@ def get_ensemble(dirpath=None, bands=["b","r","z"], smooth=0):
frequency residual for the ensemble. See doc. 4723.
'''

log = get_logger()

global _tsnr_ensembles
if _tsnr_ensembles is not None:
log.debug('Using cached TSNR ensemble')
return _tsnr_ensembles

#- all ensembles have these bands
bands = ('b', 'r', 'z')

t0 = time.time()

log=get_logger()
if dirpath is None :
dirpath = os.path.join(os.environ["DESIMODEL"],"data/tsnr")

paths = glob.glob(dirpath + '/tsnr-ensemble-*.fits')
log.info('Reading TSNR ensemble files from %s', dirpath)

paths = sorted(glob.glob(dirpath + '/tsnr-ensemble-*.fits'))

wave = {}
flux = {}
Expand All @@ -445,23 +449,22 @@ def get_ensemble(dirpath=None, bands=["b","r","z"], smooth=0):

for path in paths:
tracer = path.split('/')[-1].split('-')[2].replace('.fits','')
dat = fits.open(path)

if 'FLUXSCAL' in dat[0].header :
scale_factor = dat[0].header['FLUXSCAL']
log.info("for {} apply scale factor = {:4.3f}".format(path,scale_factor))
else :
scale_factor = 1.

for band in bands:
wave[band] = dat['WAVE_{}'.format(band.upper())].data
flux[band] = scale_factor*dat['DFLUX_{}'.format(band.upper())].data
ivar[band] = 1.e99 * np.ones_like(flux[band])

# 125: 100. A in 0.8 pixel.
if smooth > 0:
flux[band] = convolve(flux[band][0,:], Box1DKernel(smooth), boundary='extend')
flux[band] = flux[band].reshape(1, len(flux[band]))
with fits.open(path) as dat:
if 'FLUXSCAL' in dat[0].header :
scale_factor = dat[0].header['FLUXSCAL']
log.info("for {} apply scale factor = {:4.3f}".format(path,scale_factor))
else :
scale_factor = 1.

for band in bands:
wave[band] = dat['WAVE_{}'.format(band.upper())].data
flux[band] = scale_factor*dat['DFLUX_{}'.format(band.upper())].data
ivar[band] = 1.e99 * np.ones_like(flux[band])

# 125: 100. A in 0.8 pixel.
if smooth > 0:
flux[band] = convolve(flux[band][0,:], Box1DKernel(smooth), boundary='extend')
flux[band] = flux[band].reshape(1, len(flux[band]))

ensembles[tracer] = Spectra(bands, wave, flux, ivar)
ensembles[tracer].meta = dat[0].header
Expand Down Expand Up @@ -755,7 +758,6 @@ def alpha_X2(alpha):

#- Cache files from desimodel to avoid reading them N>>1 times
_camera_nea_angperpix = None
_band_ensemble = None

def calc_tsnr_fiberfracs(fibermap, etc_fiberfracs, no_offsets=False):
'''
Expand Down Expand Up @@ -984,7 +986,6 @@ def calc_tsnr2(frame, fiberflat, skymodel, fluxcalib, alpha_only=False, include_
Assumes DESIMODEL is set and up to date.
'''
global _camera_nea_angperpix
global _band_ensemble

t0=time.time()

Expand Down Expand Up @@ -1056,16 +1057,7 @@ def calc_tsnr2(frame, fiberflat, skymodel, fluxcalib, alpha_only=False, include_
nea, angperpix = read_nea(neafilename)
_camera_nea_angperpix[camera] = nea, angperpix

if _band_ensemble is None:
_band_ensemble = dict()

if band in _band_ensemble:
ensemble = _band_ensemble[band]
else:
ensembledir=os.path.join(os.environ["DESIMODEL"],"data/tsnr")
log.info("read TSNR ensemble files in {}".format(ensembledir))
ensemble = get_ensemble(ensembledir, bands=[band,])
_band_ensemble[band] = ensemble
ensemble = get_ensemble()

nspec, nwave = fluxcalib.calib.shape

Expand Down

0 comments on commit 63e0ea4

Please sign in to comment.