From bf3f760d9fe81a21ac62e0bce427e3a9a1830dfb Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 29 Mar 2024 11:53:19 -0400 Subject: [PATCH 01/11] enable selection of n_brightest, disperse each chunk only once --- jwst/wfss_contam/observations.py | 63 +++++++++++++++++++++++++------- jwst/wfss_contam/wfss_contam.py | 56 +++++++++++++++++++--------- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 19aa05d7f8..b88343f8fb 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -85,6 +85,11 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, # Create pixel lists for sources labeled in segmentation map self.create_pixel_list() + # Initialize the list of slits + self.simul_slits = datamodels.MultiSlitModel() + self.simul_slits_order = [] + self.simul_slits_sid = [] + def create_pixel_list(self): # Create a list of pixels to be dispersed, grouped per object ID. @@ -177,7 +182,8 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.simulated_image = np.zeros(self.dims, float) # Loop over all source ID's from segmentation map - for i in range(len(self.IDs)): + pool_args = [] + for i in self.IDs: if self.cache: self.cached_object[i] = {} self.cached_object[i]['x'] = [] @@ -188,8 +194,18 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.cached_object[i]['maxx'] = [] self.cached_object[i]['miny'] = [] self.cached_object[i]['maxy'] = [] - - self.disperse_chunk(i, order, wmin, wmax, sens_waves, sens_resp) + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + pool_args.append(disperse_chunk_args) + if self.max_cpu > 1: + with Pool(self.max_cpu) as mypool: + simul_slits = mypool.map(self.disperse_chunk, pool_args) + self.simul_slits.slits = simul_slits + # to do - don't think this will be able to access class variables + else: + for i in range(len(self.IDs)): + slit = self.disperse_chunk(*pool_args[i]) + if slit is not None: + self.simul_slits.slits.append(slit) def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): """ @@ -212,7 +228,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): Response (flux calibration) array from photom reference file """ - sid = int(self.IDs[c]) + sid = int(c) self.order = order self.wmin = wmin self.wmax = wmax @@ -258,19 +274,19 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # now have full pars list for all pixels for this object time1 = time.time() - if self.max_cpu > 1: - mypool = Pool(self.max_cpu) # Create the pool - all_res = mypool.imap_unordered(dispersed_pixel, pars) # Fill the pool - mypool.close() # Drain the pool - else: - all_res = [] - for i in range(len(pars)): - all_res.append(dispersed_pixel(*pars[i])) + + all_res = [] + for i in range(len(pars)): + all_res.append(dispersed_pixel(*pars[i])) + + time11 = time.time() + print(f'Elapsed time for dispersed_pixel in sid {sid}:', time11-time1) # Initialize blank image for this source this_object = np.zeros(self.dims, float) nres = 0 + bounds = [] for pp in all_res: if pp is None: continue @@ -288,6 +304,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): maxy = int(max(y)) a = sparse.coo_matrix((f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + bounds.append([minx, maxx, miny, maxy]) # Accumulate results into simulated images self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a @@ -306,7 +323,27 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): time2 = time.time() log.debug(f"Elapsed time {time2-time1} sec") - return this_object + # figure out global bounds of object + if len(bounds) > 0: + bounds = np.array(bounds) + thisobj_minx = int(np.min(bounds[:, 0])) + thisobj_maxx = int(np.max(bounds[:, 1])) + thisobj_miny = int(np.min(bounds[:, 2])) + thisobj_maxy = int(np.max(bounds[:, 3])) + slit = datamodels.SlitModel() + slit.source_id = sid + slit.name = f"source_{sid}" + slit.xstart = thisobj_minx + slit.xsize = thisobj_maxx - thisobj_minx + 1 + slit.ystart = thisobj_miny + slit.ysize = thisobj_maxy - thisobj_miny + 1 + slit.meta.wcsinfo.spectral_order = self.order + slit.data = this_object[thisobj_miny:thisobj_maxy + 1, thisobj_minx:thisobj_maxx + 1] + + self.simul_slits_order.append(self.order) + self.simul_slits_sid.append(sid) + return slit + return None def disperse_all_from_cache(self, trans=None): if not self.cache: diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 841538442f..52a125be8e 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -43,6 +43,9 @@ def contam_corr(input_model, waverange, photom, max_cores): Contamination estimate images for each source slit """ + n_sources = 5 # number of sources to simulate, for testing. note machine has 12 cores + source_0 = 2620 # source ID to start with + # Determine number of cpu's to use for multi-processing if max_cores == 'none': ncpus = 1 @@ -115,6 +118,13 @@ def contam_corr(input_model, waverange, photom, max_cores): simul_all = None obs = Observation(image_names, seg_model, grism_wcs, filter_name, boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus) + + # for testing, select a subset of the brightest sources, as extracted in extract2d + ids_in_extract2d = np.array(sorted([slit.source_id for slit in output_model.slits])) + good = (ids_in_extract2d >= source_0) + obs.IDs = list(ids_in_extract2d[good][:n_sources]) + log.info(f"Simulating only {n_sources} sources starting at sid {source_0}") + print(obs.IDs) # Create simulated grism image for each order and sum them up for order in spec_orders: @@ -133,41 +143,51 @@ def contam_corr(input_model, waverange, photom, max_cores): simul_model = datamodels.ImageModel(data=simul_all) simul_model.update(input_model, only="PRIMARY") + # save the simulation multislitmodel + obs.simul_slits.save("simulated_slits.fits") + simul_slit_sids = np.array(obs.simul_slits_sid) + simul_slit_orders = np.array(obs.simul_slits_order) + # Loop over all slits/sources to subtract contaminating spectra log.info("Creating contamination image for each individual source") contam_model = datamodels.MultiSlitModel() contam_model.update(input_model) + print('number of slits in output model', len(output_model.slits)) slits = [] for slit in output_model.slits: # Create simulated spectrum for this source only sid = slit.source_id order = slit.meta.wcsinfo.spectral_order - chunk = np.where(obs.IDs == sid)[0][0] # find chunk for this source - - obs.simulated_image = np.zeros(obs.dims) - obs.disperse_chunk(chunk, order, wmin[order], wmax[order], - sens_waves[order], sens_response[order]) - this_source = obs.simulated_image + good = (simul_slit_sids == sid) * (simul_slit_orders == order) + if not any(good): + continue + else: + print('Processing source', sid, 'order', order) + + good_idx = np.where(good)[0][0] + this_simul = obs.simul_slits.slits[good_idx] # Contamination estimate is full simulated image minus this source - contam = simul_all - this_source - - # Create a cutout of the contam image that matches the extent - # of the source slit - x1 = slit.xstart - 1 - y1 = slit.ystart - 1 - cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] - new_slit = datamodels.SlitModel(data=cutout) - copy_slit_info(slit, new_slit) + # cutting out the region and then subtracting + x1 = this_simul.xstart - 1 + y1 = this_simul.ystart - 1 + this_field = simul_all[y1:y1 + this_simul.ysize, x1:x1 + this_simul.xsize] + contam = this_field - this_simul.data + + # Create a new slit model for the contamination estimate + new_slit = datamodels.SlitModel(data=contam) + # some of this slit info is wrong, because output slit has different size that input slit now + # other problems may be caused by output slit having different size when subtracting real data + # need to fix this + copy_slit_info(slit, new_slit) slits.append(new_slit) - # Subtract the cutout from the source slit - slit.data -= cutout - # Save the contamination estimates for all slits contam_model.slits.extend(slits) + # at what point does the output model get updated with the contamination-corrected data? + # Set the step status to COMPLETE output_model.meta.cal_step.wfss_contam = 'COMPLETE' From ad6829126f524a48865b6aa3ddf9eff45f8e91f9 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 3 Apr 2024 09:38:25 -0400 Subject: [PATCH 02/11] fixed multiprocessing, reversed photom and wfss_contam steps --- jwst/pipeline/calwebb_spec2.py | 2 +- jwst/wfss_contam/disperse.py | 12 +- jwst/wfss_contam/observations.py | 266 ++++++++++++++++----------- jwst/wfss_contam/wfss_contam.py | 78 ++++---- jwst/wfss_contam/wfss_contam_step.py | 12 +- 5 files changed, 212 insertions(+), 158 deletions(-) diff --git a/jwst/pipeline/calwebb_spec2.py b/jwst/pipeline/calwebb_spec2.py index 85ddabec94..2d046b23e9 100644 --- a/jwst/pipeline/calwebb_spec2.py +++ b/jwst/pipeline/calwebb_spec2.py @@ -489,8 +489,8 @@ def _process_grism(self, data): calibrated = self.fringe(calibrated) calibrated = self.pathloss(calibrated) calibrated = self.barshadow(calibrated) - calibrated = self.wfss_contam(calibrated) calibrated = self.photom(calibrated) + calibrated = self.wfss_contam(calibrated) calibrated = self.pixel_replace(calibrated) return calibrated diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index db05c81a95..1b921de10c 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -3,11 +3,11 @@ from scipy.interpolate import interp1d from ..lib.winclip import get_clipped_pixels -from .sens1d import create_1d_sens +#from .sens1d import create_1d_sens def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + seg_wcs, grism_wcs, ID, naxis, oversample_factor=2, extrapolate_sed=False, xoffset=0, yoffset=0): """ @@ -155,13 +155,15 @@ def flux(x): return None # compute 1D sensitivity array corresponding to list of wavelengths - sens, no_cal = create_1d_sens(lams, sens_waves, sens_resp) + #sens, no_cal = create_1d_sens(lams, sens_waves, sens_resp) # Compute countrates for dispersed pixels. Note that dispersed pixel # values are naturally in units of physical fluxes, so we divide out # the sensitivity (flux calibration) values to convert to units of # countrate (DN/s). - counts = flux(lams) * areas / sens - counts[no_cal] = 0. # set to zero where no flux cal info available + # flux(lams) is either single-valued (for a single direct image) + # or an array of the same length as lams (for multiple direct images in different filters) + counts = flux(lams) * areas # / sens + #counts[no_cal] = 0. # set to zero where no flux cal info available return xs, ys, areas, lams, counts, ID diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index b88343f8fb..f20b3de9ad 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -9,6 +9,7 @@ from .disperse import dispersed_pixel import logging +import warnings log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) @@ -35,8 +36,8 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, WCS object from grism image filter : str Filter name - ID : int - ID of source to process. If zero, all sources processed. + ID : int or list-like, optional + ID(s) of source to process. If zero, all sources processed. sed_file : str Name of Spectral Energy Distribution (SED) file containing datasets matching the ID in the segmentation file and each consisting of a [[lambda],[flux]] array. @@ -97,26 +98,30 @@ def create_pixel_list(self): # When ID=0, all sources in the segmentation map are processed. # This creates a huge list of all x,y pixel indices that have non-zero values # in the seg map, sorted by those indices belonging to a particular source ID. - self.xs = [] - self.ys = [] all_IDs = np.array(list(set(np.ravel(self.seg)))) all_IDs = all_IDs[all_IDs > 0] self.IDs = all_IDs log.info(f"Loading {len(all_IDs)} sources from segmentation map") - for ID in all_IDs: - ys, xs = np.nonzero(self.seg == ID) - if len(xs) > 0 and len(ys) > 0: - self.xs.append(xs) - self.ys.append(ys) - - else: + elif isinstance(self.ID, int): # Process only the given source ID log.info(f"Loading source {self.ID} from segmentation map") - ys, xs = np.nonzero(self.seg == self.ID) + self.IDs = [self.ID] + elif isinstance(self.ID, (list, np.array)): + # Process only the given list of source IDs + log.info(f"Loading {len(self.ID)} of {len(list(set(np.ravel(self.seg))))} selected sources from segmentation map") + self.IDs = self.ID + else: + raise ValueError("ID must be an integer or a list of integers") + + self.xs = [] + self.ys = [] + for ID in self.IDs: + ys, xs = np.nonzero(self.seg == ID) if len(xs) > 0 and len(ys) > 0: - self.xs = [xs] - self.ys = [ys] - self.IDs = [self.ID] + self.xs.append(xs) + self.ys.append(ys) + + print("length of xs and ys", len(self.xs), len(self.ys)) # Populate lists of direct image flux values for the sources. self.fluxes = {} @@ -155,7 +160,7 @@ def create_pixel_list(self): for i in range(len(self.IDs)): self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]]) - def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): + def disperse_all(self, order, wmin, wmax, cache=False): """ Compute dispersed pixel values for all sources identified in the segmentation map. @@ -168,10 +173,6 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): Minimum wavelength for dispersed spectra wmax : float Maximum wavelength for dispersed spectra - sens_waves : float array - Wavelength array from photom reference file - sens_resp : float array - Response (flux calibration) array from photom reference file """ if cache: log.debug("Object caching ON") @@ -183,7 +184,7 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): # Loop over all source ID's from segmentation map pool_args = [] - for i in self.IDs: + for i in range(len(self.IDs)): if self.cache: self.cached_object[i] = {} self.cached_object[i]['x'] = [] @@ -194,97 +195,123 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.cached_object[i]['maxx'] = [] self.cached_object[i]['miny'] = [] self.cached_object[i]['maxy'] = [] - disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + disperse_chunk_args = [i, order, wmin, wmax, + self.IDs[i], self.xs[i], self.ys[i], + self.fluxes, #check shape! + self.seg_wcs, self.grism_wcs, self.dims, + self.extrapolate_sed, self.xoffset, self.yoffset] pool_args.append(disperse_chunk_args) + + # call disperse_chunk with optional multiprocessing + t0 = time.time() if self.max_cpu > 1: + log.info(f"Using multiprocessing with {self.max_cpu} cores to compute dispersion") with Pool(self.max_cpu) as mypool: - simul_slits = mypool.map(self.disperse_chunk, pool_args) - self.simul_slits.slits = simul_slits - # to do - don't think this will be able to access class variables + disperse_chunk_output = mypool.starmap(self.disperse_chunk, pool_args) else: for i in range(len(self.IDs)): - slit = self.disperse_chunk(*pool_args[i]) - if slit is not None: - self.simul_slits.slits.append(slit) - - def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): + disperse_chunk_output = self.disperse_chunk(*pool_args[i]) + t1 = time.time() + log.info(f"Wall clock time for disperse_chunk order {order}: {(t1-t0):.1f} sec") + + # Collect results into simulated image and slit models + for i, this_output in enumerate(disperse_chunk_output): + [this_image, this_bounds, this_sid, this_order] = this_output + slit = self.construct_slitmodel_for_chunk(this_image, this_bounds, this_sid, this_order) + self.simulated_image += this_image + if slit is not None: + self.simul_slits.slits.append(slit) + self.simul_slits_order.append(this_order) + self.simul_slits_sid.append(this_sid) + + @staticmethod + def disperse_chunk(c, order, wmin, wmax, sid, xs, ys, fluxes_dict, seg_wcs, grism_wcs, dims, extrapolate_sed, xoffset, yoffset): """ Method that computes dispersion for a single source. To be called after create_pixel_list(). + Static method to enable parallelization Parameters ---------- c : int - Chunk (source) number to process + Chunk (source) number to process. used to index the fluxes dict order : int - Spectral order number to process + Spectral order to process wmin : float Minimum wavelength for dispersed spectra wmax : float Maximum wavelength for dispersed spectra - sens_waves : float array - Wavelength array from photom reference file - sens_resp : float array - Response (flux calibration) array from photom reference file + sid : int + Source ID + xs : np.ndarray + X-coordinates of the the central pixel of the group of pixels + surrounding the direct image pixel index + ys : np.ndarray + Y-coordinates of the the central pixel of the group of pixels + surrounding the direct image pixel index + fluxes_dict : dict + Dictionary of fluxes for each direct image. + fluxes_dict{"lams"} is the array of wavelengths previously stored in flux list + and correspond to the central wavelengths of the filters used in + the input direct image(s). For the simple case of 1 combined direct image, + this contains a single value (e.g. 4.44 for F444W). + fluxes_dict{"fluxes"} is the array of pixel values from the direct image(s). + For the simple case of 1 combined direct image, this contains a + a single value (just like "lams"). + seg_wcs : gwcs object + WCS object from segmentation map + grism_wcs : gwcs object + WCS object from grism image + dims : tuple + Dimensions of the grism image + extrapolate_sed : bool + Flag indicating whether to extrapolate wavelength range of SED + xoffset : int + Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to full frame) + yoffset : int + Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to full frame) + + Returns + ------- + this_object : np.ndarray + Dispersed model of segmentation map source + bounds : list + The bounds of the object + sid : int + The source ID + order : int + The spectral order number """ - - sid = int(c) - self.order = order - self.wmin = wmin - self.wmax = wmax - self.sens_waves = sens_waves - self.sens_resp = sens_resp - log.info(f"Dispersing source {sid}, order {self.order}") - pars = [] # initialize params for this object + log.info(f"Dispersing source {sid}, order {order}") # Loop over all pixels in list for object "c" - log.debug(f"source contains {len(self.xs[c])} pixels") - for i in range(len(self.xs[c])): - - # Here "i" and "ID" are just indexes into the pixel list for the object + log.debug(f"source {sid} contains {len(xs)} pixels") + all_res = [] + for i in range(len(xs)): + # Here "i" indexes the pixel list for the object # being processed, as opposed to the ID number of the object itself - ID = i - # xc, yc are the coordinates of the central pixel of the group - # of pixels surrounding the direct image pixel index width = 1.0 height = 1.0 - xc = self.xs[c][i] + 0.5 * width - yc = self.ys[c][i] + 0.5 * height + xc = xs[i] + 0.5 * width + yc = ys[i] + 0.5 * height - # "lams" is the array of wavelengths previously stored in flux list - # and correspond to the central wavelengths of the filters used in - # the input direct image(s). For the simple case of 1 combined direct image, - # this contains a single value (e.g. 4.44 for F444W). - - # "fluxes" is the array of pixel values from the direct image(s). - # For the simple case of 1 combined direct image, this contains a - # a single value (just like "lams"). fluxes, lams = map(np.array, zip(*[ - (self.fluxes[lm][c][i], lm) for lm in sorted(self.fluxes.keys()) - if self.fluxes[lm][c][i] != 0 + (fluxes_dict[lm][c][i], lm) for lm in sorted(fluxes_dict.keys()) + if fluxes_dict[lm][c][i] != 0 ])) - pars_i = (xc, yc, width, height, lams, fluxes, self.order, - self.wmin, self.wmax, self.sens_waves, self.sens_resp, - self.seg_wcs, self.grism_wcs, ID, self.dims[::-1], 2, - self.extrapolate_sed, self.xoffset, self.yoffset) - - pars.append(pars_i) - # now have full pars list for all pixels for this object - - time1 = time.time() - - all_res = [] - for i in range(len(pars)): - all_res.append(dispersed_pixel(*pars[i])) - - time11 = time.time() - print(f'Elapsed time for dispersed_pixel in sid {sid}:', time11-time1) + pars_i = (xc, yc, width, height, lams, fluxes, order, + wmin, wmax, + seg_wcs, grism_wcs, i, dims[::-1], 2, + extrapolate_sed, xoffset, yoffset) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in scalar divide") + warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice") + all_res.append(dispersed_pixel(*pars_i)) # Initialize blank image for this source - this_object = np.zeros(self.dims, float) - + this_object = np.zeros(dims, float) nres = 0 bounds = [] for pp in all_res: @@ -307,21 +334,17 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): bounds.append([minx, maxx, miny, maxy]) # Accumulate results into simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a this_object[miny:maxy + 1, minx:maxx + 1] += a - if self.cache: - self.cached_object[c]['x'].append(x) - self.cached_object[c]['y'].append(y) - self.cached_object[c]['f'].append(f) - self.cached_object[c]['w'].append(w) - self.cached_object[c]['minx'].append(minx) - self.cached_object[c]['maxx'].append(maxx) - self.cached_object[c]['miny'].append(miny) - self.cached_object[c]['maxy'].append(maxy) - - time2 = time.time() - log.debug(f"Elapsed time {time2-time1} sec") + #if self.cache: + # self.cached_object[c]['x'].append(x) + # self.cached_object[c]['y'].append(y) + # self.cached_object[c]['f'].append(f) + # self.cached_object[c]['w'].append(w) + # self.cached_object[c]['minx'].append(minx) + # self.cached_object[c]['maxx'].append(maxx) + # self.cached_object[c]['miny'].append(miny) + # self.cached_object[c]['maxy'].append(maxy) # figure out global bounds of object if len(bounds) > 0: @@ -330,20 +353,45 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): thisobj_maxx = int(np.max(bounds[:, 1])) thisobj_miny = int(np.min(bounds[:, 2])) thisobj_maxy = int(np.max(bounds[:, 3])) - slit = datamodels.SlitModel() - slit.source_id = sid - slit.name = f"source_{sid}" - slit.xstart = thisobj_minx - slit.xsize = thisobj_maxx - thisobj_minx + 1 - slit.ystart = thisobj_miny - slit.ysize = thisobj_maxy - thisobj_miny + 1 - slit.meta.wcsinfo.spectral_order = self.order - slit.data = this_object[thisobj_miny:thisobj_maxy + 1, thisobj_minx:thisobj_maxx + 1] - - self.simul_slits_order.append(self.order) - self.simul_slits_sid.append(sid) - return slit - return None + thisobj_bounds = [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] + return (this_object, thisobj_bounds, sid, order) + return (this_object, None, sid, order) + + @staticmethod + def construct_slitmodel_for_chunk(chunk_data, bounds, sid, order): + ''' + Parameters + ---------- + chunk_data : np.ndarray + Dispersed model of segmentation map source + bounds : list + The bounds of the object + sid : int + The source ID + order : int + The spectral order number + + Returns + ------- + slit : `jwst.datamodels.SlitModel` + Slit model containing the dispersed pixel values + ''' + if bounds is None: + return None + [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] = bounds + + slit = datamodels.SlitModel() + slit.source_id = sid + slit.name = f"source_{sid}" + slit.xstart = thisobj_minx + slit.xsize = thisobj_maxx - thisobj_minx + 1 + slit.ystart = thisobj_miny + slit.ysize = thisobj_maxy - thisobj_miny + 1 + slit.meta.wcsinfo.spectral_order = order + slit.data = chunk_data[thisobj_miny:thisobj_maxy + 1, thisobj_minx:thisobj_maxx + 1] + + return slit + def disperse_all_from_cache(self, trans=None): if not self.cache: diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 52a125be8e..8a7aac6b8e 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,6 +1,4 @@ -# -# Top level module for WFSS contamination correction. -# +import matplotlib.pyplot as plt import logging import multiprocessing import numpy as np @@ -8,13 +6,12 @@ from stdatamodels.jwst import datamodels from .observations import Observation -from .sens1d import get_photom_data log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -def contam_corr(input_model, waverange, photom, max_cores): +def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): """ The main WFSS contamination correction function @@ -24,14 +21,18 @@ def contam_corr(input_model, waverange, photom, max_cores): Input data model containing 2D spectral cutouts waverange : `~jwst.datamodels.WavelengthrangeModel` Wavelength range reference file model - photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel` - Photom (flux cal) reference file model max_cores : string Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other allowable values are 'quarter', 'half', and 'all', which indicate the fraction of cores to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). + n_sources : int + Number of sources to simulate. If None, then all sources in the + input model will be simulated. This is primarily useful for testing. + source_0 : int + Source ID to start with when selecting sources to simulate. This + is primarily useful for testing. Returns ------- @@ -43,8 +44,6 @@ def contam_corr(input_model, waverange, photom, max_cores): Contamination estimate images for each source slit """ - n_sources = 5 # number of sources to simulate, for testing. note machine has 12 cores - source_0 = 2620 # source ID to start with # Determine number of cpu's to use for multi-processing if max_cores == 'none': @@ -104,34 +103,32 @@ def contam_corr(input_model, waverange, photom, max_cores): # Load lists of wavelength ranges and flux cal info for all orders wmin = {} wmax = {} - sens_waves = {} - sens_response = {} for order in spec_orders: wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) wmin[order] = wavelength_range[order][0] wmax[order] = wavelength_range[order][1] - # Load the sensitivity (inverse flux cal) data for this mode and order - sens_waves[order], sens_response[order] = get_photom_data(photom, filter_kwd, pupil_kwd, order) log.debug(f"wmin={wmin}, wmax={wmax}") - # Initialize the simulated image object + # for testing, select a subset of the brightest sources, as extracted in extract2d + ids_in_extract2d = np.array([slit.source_id for slit in output_model.slits]) + good = (ids_in_extract2d >= source_0) + selected_IDs = list(ids_in_extract2d[good])[:n_sources] simul_all = None obs = Observation(image_names, seg_model, grism_wcs, filter_name, - boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus) + boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus, + ID=selected_IDs) - # for testing, select a subset of the brightest sources, as extracted in extract2d - ids_in_extract2d = np.array(sorted([slit.source_id for slit in output_model.slits])) - good = (ids_in_extract2d >= source_0) - obs.IDs = list(ids_in_extract2d[good][:n_sources]) - log.info(f"Simulating only {n_sources} sources starting at sid {source_0}") - print(obs.IDs) + good_slits = [slit for slit in output_model.slits if slit.source_id in obs.IDs] + #output_model.slits = good_slits #not sure why, but this fails to index properly + output_model = datamodels.MultiSlitModel() + output_model.slits.extend(good_slits) + log.info(f"Simulating only the first {n_sources} sources starting at index {source_0}") # Create simulated grism image for each order and sum them up for order in spec_orders: log.info(f"Creating full simulated grism image for order {order}") - obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order], - sens_response[order]) + obs.disperse_all(order, wmin[order], wmax[order]) # Accumulate result for this order into the combined image if simul_all is None: @@ -144,7 +141,9 @@ def contam_corr(input_model, waverange, photom, max_cores): simul_model.update(input_model, only="PRIMARY") # save the simulation multislitmodel - obs.simul_slits.save("simulated_slits.fits") + obs.simul_slits.save("simulated_slits.fits", overwrite=True) + + # need to re-make these now that I changed disperse_chunk simul_slit_sids = np.array(obs.simul_slits_sid) simul_slit_orders = np.array(obs.simul_slits_order) @@ -152,31 +151,37 @@ def contam_corr(input_model, waverange, photom, max_cores): log.info("Creating contamination image for each individual source") contam_model = datamodels.MultiSlitModel() contam_model.update(input_model) - print('number of slits in output model', len(output_model.slits)) slits = [] for slit in output_model.slits: - # Create simulated spectrum for this source only + # Retrieve simulated slit for this source only sid = slit.source_id order = slit.meta.wcsinfo.spectral_order good = (simul_slit_sids == sid) * (simul_slit_orders == order) if not any(good): continue else: - print('Processing source', sid, 'order', order) + print('Subtracting contamination for source', sid, 'order', order) good_idx = np.where(good)[0][0] this_simul = obs.simul_slits.slits[good_idx] - # Contamination estimate is full simulated image minus this source - # cutting out the region and then subtracting - x1 = this_simul.xstart - 1 - y1 = this_simul.ystart - 1 - this_field = simul_all[y1:y1 + this_simul.ysize, x1:x1 + this_simul.xsize] - contam = this_field - this_simul.data - # Create a new slit model for the contamination estimate - new_slit = datamodels.SlitModel(data=contam) + fullframe_sim = np.zeros(obs.dims) + y0 = this_simul.ystart + x0 = this_simul.xstart + #print(obs.dims, this_simul.data.shape, slit.data.shape) + #print(y0, x0) + fullframe_sim[y0:y0 + this_simul.ysize, x0:x0 + this_simul.xsize] = this_simul.data + contam = simul_all - fullframe_sim + + # Create a cutout of the contam image that matches the extent + # of the source slit + x1 = slit.xstart - 1 + y1 = slit.ystart - 1 + cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] + new_slit = datamodels.SlitModel(data=cutout) + # TO DO: # some of this slit info is wrong, because output slit has different size that input slit now # other problems may be caused by output slit having different size when subtracting real data # need to fix this @@ -185,6 +190,9 @@ def contam_corr(input_model, waverange, photom, max_cores): # Save the contamination estimates for all slits contam_model.slits.extend(slits) + print('number of slits in contam model', len(contam_model.slits)) + print('number of slits in output model', len(output_model.slits)) + print('number of slits in simul model', len(obs.simul_slits.slits)) # at what point does the output model get updated with the contamination-corrected data? diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index 91ee451195..b53057410c 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -22,7 +22,7 @@ class WfssContamStep(Step): skip = boolean(default=True) """ - reference_file_types = ['photom', 'wavelengthrange'] + reference_file_types = ['wavelengthrange'] def process(self, input_model, *args, **kwargs): @@ -35,15 +35,11 @@ def process(self, input_model, *args, **kwargs): self.log.info(f'Using WAVELENGTHRANGE reference file {waverange_ref}') waverange_model = datamodels.WavelengthrangeModel(waverange_ref) - # Get the photom ref file - photom_ref = self.get_reference_file(dm, 'photom') - self.log.info(f'Using PHOTOM reference file {photom_ref}') - photom_model = datamodels.open(photom_ref) - result, simul, contam = wfss_contam.contam_corr(dm, waverange_model, - photom_model, - max_cores) + max_cores, + n_sources=12, + source_0=0) # Save intermediate results, if requested if self.save_simulated_image: From 9c38a01281a8b8c384d65c63bf2822a0911d85e5 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 3 Apr 2024 14:47:11 -0400 Subject: [PATCH 03/11] save state photom before wfss_contam --- jwst/wfss_contam/wfss_contam.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 8a7aac6b8e..173d5145a4 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -182,12 +182,13 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] new_slit = datamodels.SlitModel(data=cutout) # TO DO: - # some of this slit info is wrong, because output slit has different size that input slit now - # other problems may be caused by output slit having different size when subtracting real data - # need to fix this + # not sure if the slit metadata is getting transferred properly copy_slit_info(slit, new_slit) slits.append(new_slit) + # Subtract the cutout from the source slit + slit.data -= cutout + # Save the contamination estimates for all slits contam_model.slits.extend(slits) print('number of slits in contam model', len(contam_model.slits)) From 373737feadec0ac2c8167af8b43ce1458e408285 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 5 Apr 2024 16:28:28 -0400 Subject: [PATCH 04/11] JP-2075: fixed scaling issues, some loose ends still remaining though --- jwst/pipeline/calwebb_spec2.py | 2 +- jwst/wfss_contam/disperse.py | 13 +-- jwst/wfss_contam/observations.py | 82 ++++++++++++++++--- jwst/wfss_contam/wfss_contam.py | 115 +++++++++++++++------------ jwst/wfss_contam/wfss_contam_step.py | 11 ++- 5 files changed, 152 insertions(+), 71 deletions(-) diff --git a/jwst/pipeline/calwebb_spec2.py b/jwst/pipeline/calwebb_spec2.py index 2d046b23e9..85ddabec94 100644 --- a/jwst/pipeline/calwebb_spec2.py +++ b/jwst/pipeline/calwebb_spec2.py @@ -489,8 +489,8 @@ def _process_grism(self, data): calibrated = self.fringe(calibrated) calibrated = self.pathloss(calibrated) calibrated = self.barshadow(calibrated) - calibrated = self.photom(calibrated) calibrated = self.wfss_contam(calibrated) + calibrated = self.photom(calibrated) calibrated = self.pixel_replace(calibrated) return calibrated diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index 1b921de10c..4ae6f521dc 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -1,13 +1,14 @@ import numpy as np from scipy.interpolate import interp1d +import warnings from ..lib.winclip import get_clipped_pixels -#from .sens1d import create_1d_sens +from .sens1d import create_1d_sens def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, - seg_wcs, grism_wcs, ID, naxis, + sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, oversample_factor=2, extrapolate_sed=False, xoffset=0, yoffset=0): """ @@ -155,7 +156,7 @@ def flux(x): return None # compute 1D sensitivity array corresponding to list of wavelengths - #sens, no_cal = create_1d_sens(lams, sens_waves, sens_resp) + sens, no_cal = create_1d_sens(lams, sens_waves, sens_resp) # Compute countrates for dispersed pixels. Note that dispersed pixel # values are naturally in units of physical fluxes, so we divide out @@ -163,7 +164,9 @@ def flux(x): # countrate (DN/s). # flux(lams) is either single-valued (for a single direct image) # or an array of the same length as lams (for multiple direct images in different filters) - counts = flux(lams) * areas # / sens - #counts[no_cal] = 0. # set to zero where no flux cal info available + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") + counts = flux(lams) * areas / (sens * oversample_factor) + counts[no_cal] = 0. # set to zero where no flux cal info available return xs, ys, areas, lams, counts, ID diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index f20b3de9ad..b27d3e71c2 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,6 +1,7 @@ import time import numpy as np -from multiprocessing import Pool +import multiprocessing as mp +import concurrent.futures from scipy import sparse @@ -11,10 +12,53 @@ import logging import warnings +from photutils.background import Background2D, MedianBackground +from astropy.stats import SigmaClip + log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclude_percentile=30.0): + """ + Simple astropy background subtraction + + Parameters + ---------- + data : np.ndarray + 2D array of pixel values + box_size : tuple + Size of box in pixels to use for background estimation. + If not set, defaults to 1/5 of the image size. + filter_size : tuple + Size of filter to use for background estimation + sigma : float + Sigma threshold for background clipping + exclude_percentile : float + Percentage of masked pixels above which box is excluded from background estimation + + Returns + ------- + data : np.ndarray + 2D array of pixel values with background subtracted + + Notes + ----- + Improper background subtraction in input _i2d image leads to extra flux + in the simulated dispersed image, and was one cause of flux scaling issues + in a previous version. + """ + if box_size is None: + box_size = (int(data.shape[0]/5), int(data.shape[1]/5)) + sigma_clip = SigmaClip(sigma=sigma) + bkg_estimator = MedianBackground() + bkg = Background2D(data, (500, 500), filter_size=filter_size, + sigma_clip=sigma_clip, bkg_estimator=bkg_estimator, + exclude_percentile=exclude_percentile) + + return data - bkg.background + + class Observation: """This class defines an actual observation. It is tied to a single grism image.""" @@ -121,8 +165,6 @@ def create_pixel_list(self): self.xs.append(xs) self.ys.append(ys) - print("length of xs and ys", len(self.xs), len(self.ys)) - # Populate lists of direct image flux values for the sources. self.fluxes = {} for dir_image_name in self.dir_image_names: @@ -130,6 +172,7 @@ def create_pixel_list(self): log.info(f"Using direct image {dir_image_name}") with datamodels.open(dir_image_name) as model: dimage = model.data + dimage = background_subtract(dimage) if self.sed_file is None: # Default pipeline will use sed_file=None, so we need to compute @@ -160,7 +203,7 @@ def create_pixel_list(self): for i in range(len(self.IDs)): self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]]) - def disperse_all(self, order, wmin, wmax, cache=False): + def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): """ Compute dispersed pixel values for all sources identified in the segmentation map. @@ -173,6 +216,10 @@ def disperse_all(self, order, wmin, wmax, cache=False): Minimum wavelength for dispersed spectra wmax : float Maximum wavelength for dispersed spectra + sens_waves : float array + Wavelength array from photom reference file + sens_resp : float array + Response (flux calibration) array from photom reference file """ if cache: log.debug("Object caching ON") @@ -195,9 +242,10 @@ def disperse_all(self, order, wmin, wmax, cache=False): self.cached_object[i]['maxx'] = [] self.cached_object[i]['miny'] = [] self.cached_object[i]['maxy'] = [] - disperse_chunk_args = [i, order, wmin, wmax, + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp, self.IDs[i], self.xs[i], self.ys[i], - self.fluxes, #check shape! + self.fluxes, self.seg_wcs, self.grism_wcs, self.dims, self.extrapolate_sed, self.xoffset, self.yoffset] pool_args.append(disperse_chunk_args) @@ -206,11 +254,17 @@ def disperse_all(self, order, wmin, wmax, cache=False): t0 = time.time() if self.max_cpu > 1: log.info(f"Using multiprocessing with {self.max_cpu} cores to compute dispersion") - with Pool(self.max_cpu) as mypool: + ctx = mp.get_context("forkserver") + with ctx.Pool(self.max_cpu) as mypool: disperse_chunk_output = mypool.starmap(self.disperse_chunk, pool_args) + #with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_cpu) as executor: + # these_futures = [executor.submit(self.disperse_chunk, *args) for args in pool_args] + # concurrent.futures.wait(these_futures) + # disperse_chunk_output = [future.result() for future in these_futures] else: + disperse_chunk_output = [] for i in range(len(self.IDs)): - disperse_chunk_output = self.disperse_chunk(*pool_args[i]) + disperse_chunk_output.append(self.disperse_chunk(*pool_args[i])) t1 = time.time() log.info(f"Wall clock time for disperse_chunk order {order}: {(t1-t0):.1f} sec") @@ -225,7 +279,7 @@ def disperse_all(self, order, wmin, wmax, cache=False): self.simul_slits_sid.append(this_sid) @staticmethod - def disperse_chunk(c, order, wmin, wmax, sid, xs, ys, fluxes_dict, seg_wcs, grism_wcs, dims, extrapolate_sed, xoffset, yoffset): + def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, fluxes_dict, seg_wcs, grism_wcs, dims, extrapolate_sed, xoffset, yoffset): """ Method that computes dispersion for a single source. To be called after create_pixel_list(). @@ -241,6 +295,10 @@ def disperse_chunk(c, order, wmin, wmax, sid, xs, ys, fluxes_dict, seg_wcs, gris Minimum wavelength for dispersed spectra wmax : float Maximum wavelength for dispersed spectra + sens_waves : float array + Wavelength array from photom reference file + sens_resp : float array + Response (flux calibration) array from photom reference file sid : int Source ID xs : np.ndarray @@ -288,8 +346,8 @@ def disperse_chunk(c, order, wmin, wmax, sid, xs, ys, fluxes_dict, seg_wcs, gris log.debug(f"source {sid} contains {len(xs)} pixels") all_res = [] for i in range(len(xs)): - # Here "i" indexes the pixel list for the object - # being processed, as opposed to the ID number of the object itself + # Here "i" indexes the pixel list for the segment + # being processed, as opposed to the ID number of the segment width = 1.0 height = 1.0 @@ -302,7 +360,7 @@ def disperse_chunk(c, order, wmin, wmax, sid, xs, ys, fluxes_dict, seg_wcs, gris ])) pars_i = (xc, yc, width, height, lams, fluxes, order, - wmin, wmax, + wmin, wmax, sens_waves, sens_resp, seg_wcs, grism_wcs, i, dims[::-1], 2, extrapolate_sed, xoffset, yoffset) with warnings.catch_warnings(): diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 173d5145a4..087d63870a 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,17 +1,50 @@ -import matplotlib.pyplot as plt import logging import multiprocessing import numpy as np from stdatamodels.jwst import datamodels +from astropy.table import Table from .observations import Observation +from .sens1d import get_photom_data log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): +def _determine_multiprocessing_ncores(max_cores): + + """Determine the number of cores to use for multiprocessing. + + Parameters + ---------- + max_cores : string + See docstring of contam_corr + + Returns + ------- + ncpus : int + Number of cores to use for multiprocessing + + """ + if max_cores == 'none': + ncpus = 1 + else: + num_cores = multiprocessing.cpu_count() + if max_cores == 'quarter': + ncpus = num_cores // 4 or 1 + elif max_cores == 'half': + ncpus = num_cores // 2 or 1 + elif max_cores == 'all': + ncpus = num_cores + else: + ncpus = 1 + log.debug(f"Found {num_cores} cores; using {ncpus}") + + return ncpus + + +def contam_corr(input_model, waverange, photom, max_cores, brightest_n=None): """ The main WFSS contamination correction function @@ -21,18 +54,18 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): Input data model containing 2D spectral cutouts waverange : `~jwst.datamodels.WavelengthrangeModel` Wavelength range reference file model + photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel` + Photom (flux cal) reference file model max_cores : string Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other allowable values are 'quarter', 'half', and 'all', which indicate the fraction of cores to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). - n_sources : int + brightest_n : int Number of sources to simulate. If None, then all sources in the - input model will be simulated. This is primarily useful for testing. - source_0 : int - Source ID to start with when selecting sources to simulate. This - is primarily useful for testing. + input model will be simulated. Requires loading the source catalog + file if not None. Returns ------- @@ -45,28 +78,13 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): """ - # Determine number of cpu's to use for multi-processing - if max_cores == 'none': - ncpus = 1 - else: - num_cores = multiprocessing.cpu_count() - if max_cores == 'quarter': - ncpus = num_cores // 4 or 1 - elif max_cores == 'half': - ncpus = num_cores // 2 or 1 - elif max_cores == 'all': - ncpus = num_cores - else: - ncpus = 1 - log.debug(f"Found {num_cores} cores; using {ncpus}") + ncpus = _determine_multiprocessing_ncores(max_cores) # Initialize output model output_model = input_model.copy() - # Get the segmentation map for this grism exposure + # Get the segmentation map, direct image for this grism exposure seg_model = datamodels.open(input_model.meta.segmentation_map) - - # Get the direct image from which the segmentation map was constructed direct_file = input_model.meta.direct_image image_names = [direct_file] log.debug(f"Direct image names={image_names}") @@ -100,37 +118,37 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): else: filter_name = filter_kwd - # Load lists of wavelength ranges and flux cal info for all orders - wmin = {} - wmax = {} - for order in spec_orders: - wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) - wmin[order] = wavelength_range[order][0] - wmax[order] = wavelength_range[order][1] - log.debug(f"wmin={wmin}, wmax={wmax}") - - # for testing, select a subset of the brightest sources, as extracted in extract2d - ids_in_extract2d = np.array([slit.source_id for slit in output_model.slits]) - good = (ids_in_extract2d >= source_0) - selected_IDs = list(ids_in_extract2d[good])[:n_sources] - simul_all = None + # select a subset of the brightest sources using source catalog + if brightest_n is not None: + source_catalog = Table.read(input_model.meta.source_catalog, format='ascii.ecsv') + source_catalog.sort("isophotal_abmag", reverse=False) #magnitudes in ascending order, since brighter is smaller mag number + selected_IDs = list(source_catalog["label"])[:brightest_n] + else: + selected_IDs = None + obs = Observation(image_names, seg_model, grism_wcs, filter_name, boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus, ID=selected_IDs) good_slits = [slit for slit in output_model.slits if slit.source_id in obs.IDs] - #output_model.slits = good_slits #not sure why, but this fails to index properly output_model = datamodels.MultiSlitModel() output_model.slits.extend(good_slits) - log.info(f"Simulating only the first {n_sources} sources starting at index {source_0}") + log.info(f"Simulating only the first {brightest_n} sources") + - # Create simulated grism image for each order and sum them up + simul_all = None for order in spec_orders: - log.info(f"Creating full simulated grism image for order {order}") - obs.disperse_all(order, wmin[order], wmax[order]) + # Load lists of wavelength ranges and flux cal info + wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) + wmin = wavelength_range[order][0] + wmax = wavelength_range[order][1] + log.debug(f"wmin={wmin}, wmax={wmax} for order {order}") + sens_waves, sens_response = get_photom_data(photom, filter_kwd, pupil_kwd, order) - # Accumulate result for this order into the combined image + # Create simulated grism image for each order and sum them up + log.info(f"Creating full simulated grism image for order {order}") + obs.disperse_all(order, wmin, wmax, sens_waves, sens_response) if simul_all is None: simul_all = obs.simulated_image else: @@ -142,8 +160,6 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): # save the simulation multislitmodel obs.simul_slits.save("simulated_slits.fits", overwrite=True) - - # need to re-make these now that I changed disperse_chunk simul_slit_sids = np.array(obs.simul_slits_sid) simul_slit_orders = np.array(obs.simul_slits_order) @@ -159,19 +175,18 @@ def contam_corr(input_model, waverange, max_cores, n_sources=None, source_0=0): order = slit.meta.wcsinfo.spectral_order good = (simul_slit_sids == sid) * (simul_slit_orders == order) if not any(good): + log.warning(f"Source {sid} order {order} requested by input slit model \ + but not found in simulated slits") continue else: print('Subtracting contamination for source', sid, 'order', order) - good_idx = np.where(good)[0][0] this_simul = obs.simul_slits.slits[good_idx] - + # cut out this source's contamination from the full simulated image fullframe_sim = np.zeros(obs.dims) y0 = this_simul.ystart x0 = this_simul.xstart - #print(obs.dims, this_simul.data.shape, slit.data.shape) - #print(y0, x0) fullframe_sim[y0:y0 + this_simul.ysize, x0:x0 + this_simul.xsize] = this_simul.data contam = simul_all - fullframe_sim diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index b53057410c..1b41fb468f 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -22,7 +22,7 @@ class WfssContamStep(Step): skip = boolean(default=True) """ - reference_file_types = ['wavelengthrange'] + reference_file_types = ['photom', 'wavelengthrange'] def process(self, input_model, *args, **kwargs): @@ -35,11 +35,16 @@ def process(self, input_model, *args, **kwargs): self.log.info(f'Using WAVELENGTHRANGE reference file {waverange_ref}') waverange_model = datamodels.WavelengthrangeModel(waverange_ref) + # Get the photom ref file + photom_ref = self.get_reference_file(dm, 'photom') + self.log.info(f'Using PHOTOM reference file {photom_ref}') + photom_model = datamodels.open(photom_ref) + result, simul, contam = wfss_contam.contam_corr(dm, waverange_model, + photom_model, max_cores, - n_sources=12, - source_0=0) + brightest_n=150) # Save intermediate results, if requested if self.save_simulated_image: From 89e63388b377709bbe455f8c021892fa66412cfc Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 9 Apr 2024 17:06:50 -0400 Subject: [PATCH 05/11] added some unit tests --- jwst/wfss_contam/disperse.py | 62 +- jwst/wfss_contam/observations.py | 74 +- jwst/wfss_contam/tests/__init__.py | 0 jwst/wfss_contam/tests/data/__init__.py | 0 jwst/wfss_contam/tests/data/grism_wcs.asdf | 1180 +++++++++++++++++ .../tests/data/segmentation_wcs.asdf | Bin 0 -> 4767 bytes jwst/wfss_contam/tests/test_disperse.py | 46 + jwst/wfss_contam/tests/test_observations.py | 191 +++ jwst/wfss_contam/tests/test_wfss_contam.py | 13 + jwst/wfss_contam/wfss_contam.py | 18 +- jwst/wfss_contam/wfss_contam_step.py | 7 +- 11 files changed, 1533 insertions(+), 58 deletions(-) create mode 100644 jwst/wfss_contam/tests/__init__.py create mode 100644 jwst/wfss_contam/tests/data/__init__.py create mode 100644 jwst/wfss_contam/tests/data/grism_wcs.asdf create mode 100644 jwst/wfss_contam/tests/data/segmentation_wcs.asdf create mode 100644 jwst/wfss_contam/tests/test_disperse.py create mode 100644 jwst/wfss_contam/tests/test_observations.py create mode 100644 jwst/wfss_contam/tests/test_wfss_contam.py diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index 4ae6f521dc..8a4ecde4a0 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -7,6 +7,44 @@ from .sens1d import create_1d_sens +def interpolate_fluxes(lams, flxs, extrapolate_sed): + ''' + Parameters + ---------- + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + flxs : float array + Array of fluxes (flam) for the pixels contained in x0, y0. If a single + direct image is in use, this will be a single value. + extrapolate_sed : bool + Whether to allow for the SED of the object to be extrapolated when it does not fully cover the + needed wavelength range. Default if False. + + Returns + ------- + flux : function + Function that returns the flux at a given wavelength. If only one direct image is in use, this + function will always return the same value + ''' + + if len(lams) > 1: + # If we have direct image flux values from more than one filter (lambda), + # we have the option to extrapolate the fluxes outside the + # wavelength range of the direct images + if extrapolate_sed is False: + return interp1d(lams, flxs, fill_value=0., bounds_error=False) + else: + return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) + else: + # If we only have flux from one lambda, just use that + # single flux value at all wavelengths + def flux(x): + return flxs[0] + return flux + + + def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, oversample_factor=2, extrapolate_sed=False, xoffset=0, @@ -84,20 +122,8 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, sky_to_imgxy = grism_wcs.get_transform('world', 'detector') imgxy_to_grismxy = grism_wcs.get_transform('detector', 'grism_detector') - # Setup function for retrieving flux values at each dispersed wavelength - if len(lams) > 1: - # If we have direct image flux values from more than one filter (lambda), - # we have the option to extrapolate the fluxes outside the - # wavelength range of the direct images - if extrapolate_sed is False: - flux = interp1d(lams, flxs, fill_value=0., bounds_error=False) - else: - flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) - else: - # If we only have flux from one lambda, just use that - # single flux value at all wavelengths - def flux(x): - return flxs[0] + # Set up function for retrieving flux values at each dispersed wavelength + flux = interpolate_fluxes(lams, flxs, extrapolate_sed) # Get x/y positions in the grism image corresponding to wmin and wmax: # Start with RA/Dec of the input pixel position in segmentation map, @@ -116,11 +142,11 @@ def flux(x): # Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, # whichever is smaller, divided by oversampling requested - input_dlam = np.median(lams[1:] - lams[:-1]) - if input_dlam < dw: - dlam = input_dlam / oversample_factor + if len(lams) > 1: + input_dlam = np.median(lams[1:] - lams[:-1]) + if input_dlam < dw: + dlam = input_dlam / oversample_factor else: - # this value gets used when we only have 1 direct image wavelength dlam = dw / oversample_factor # Create list of wavelengths on which to compute dispersed pixels diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index b27d3e71c2..8e75648e35 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,7 +1,6 @@ import time import numpy as np import multiprocessing as mp -import concurrent.futures from scipy import sparse @@ -52,17 +51,48 @@ def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclu box_size = (int(data.shape[0]/5), int(data.shape[1]/5)) sigma_clip = SigmaClip(sigma=sigma) bkg_estimator = MedianBackground() - bkg = Background2D(data, (500, 500), filter_size=filter_size, + bkg = Background2D(data, box_size, filter_size=filter_size, sigma_clip=sigma_clip, bkg_estimator=bkg_estimator, exclude_percentile=exclude_percentile) return data - bkg.background +def _select_ids(ID, all_IDs): + ''' + Select the source IDs to be processed based on the input ID parameter. + + Parameters + ---------- + ID : int or list-like + ID(s) of source to process. If None, all sources processed. + all_IDs : np.ndarray + Array of all source IDs in the segmentation map + + Returns + ------- + selected_IDs : list + List of selected source IDs + ''' + if ID is None: + log.info(f"Loading all {len(all_IDs)} sources from segmentation map") + return all_IDs + + elif isinstance(ID, int): + log.info(f"Loading single source {ID} from segmentation map") + return [ID] + + elif isinstance(ID, list) or isinstance(ID, np.ndarray): + log.info(f"Loading {len(ID)} of {len(all_IDs)} selected sources from segmentation map") + return list(ID) + else: + raise ValueError("ID must be an integer or a list of integers") + + class Observation: """This class defines an actual observation. It is tied to a single grism image.""" - def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, + def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=None, sed_file=None, extrapolate_sed=False, boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1): @@ -99,9 +129,10 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, self.seg_wcs = segmap_model.meta.wcs self.grism_wcs = grism_wcs self.ID = ID - self.IDs = [] self.dir_image_names = direct_images self.seg = segmap_model.data + all_ids = np.array(list(set(np.ravel(self.seg)))) + self.IDs = _select_ids(ID, all_ids) self.filter = filter self.sed_file = sed_file # should always be NONE for baseline pipeline (use flat SED) self.cache = False @@ -114,9 +145,9 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, if len(boundaries) == 0: log.debug("No boundaries passed.") self.xstart = 0 - self.xend = self.xstart + self.dims[0] - 1 + self.xend = self.xstart + self.seg.shape[0] - 1 self.ystart = 0 - self.yend = self.ystart + self.dims[1] - 1 + self.yend = self.ystart + self.seg.shape[1] - 1 else: self.xstart, self.xend, self.ystart, self.yend = boundaries self.dims = (self.yend - self.ystart + 1, self.xend - self.xstart + 1) @@ -136,26 +167,10 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, self.simul_slits_sid = [] def create_pixel_list(self): - # Create a list of pixels to be dispersed, grouped per object ID. - - if self.ID == 0: - # When ID=0, all sources in the segmentation map are processed. - # This creates a huge list of all x,y pixel indices that have non-zero values - # in the seg map, sorted by those indices belonging to a particular source ID. - all_IDs = np.array(list(set(np.ravel(self.seg)))) - all_IDs = all_IDs[all_IDs > 0] - self.IDs = all_IDs - log.info(f"Loading {len(all_IDs)} sources from segmentation map") - elif isinstance(self.ID, int): - # Process only the given source ID - log.info(f"Loading source {self.ID} from segmentation map") - self.IDs = [self.ID] - elif isinstance(self.ID, (list, np.array)): - # Process only the given list of source IDs - log.info(f"Loading {len(self.ID)} of {len(list(set(np.ravel(self.seg))))} selected sources from segmentation map") - self.IDs = self.ID - else: - raise ValueError("ID must be an integer or a list of integers") + ''' + Create a list of pixels to be dispersed, grouped per object ID. + When ID is None, all sources in the segmentation map are processed. + ''' self.xs = [] self.ys = [] @@ -257,10 +272,7 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): ctx = mp.get_context("forkserver") with ctx.Pool(self.max_cpu) as mypool: disperse_chunk_output = mypool.starmap(self.disperse_chunk, pool_args) - #with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_cpu) as executor: - # these_futures = [executor.submit(self.disperse_chunk, *args) for args in pool_args] - # concurrent.futures.wait(these_futures) - # disperse_chunk_output = [future.result() for future in these_futures] + else: disperse_chunk_output = [] for i in range(len(self.IDs)): @@ -363,6 +375,8 @@ def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, flu wmin, wmax, sens_waves, sens_resp, seg_wcs, grism_wcs, i, dims[::-1], 2, extrapolate_sed, xoffset, yoffset) + if i == 0: + print(pars_i) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in scalar divide") warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice") diff --git a/jwst/wfss_contam/tests/__init__.py b/jwst/wfss_contam/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/wfss_contam/tests/data/__init__.py b/jwst/wfss_contam/tests/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/wfss_contam/tests/data/grism_wcs.asdf b/jwst/wfss_contam/tests/data/grism_wcs.asdf new file mode 100644 index 0000000000..8334394f81 --- /dev/null +++ b/jwst/wfss_contam/tests/data/grism_wcs.asdf @@ -0,0 +1,1180 @@ +#ASDF 1.0.0 +#ASDF_STANDARD 1.5.0 +%YAML 1.1 +%TAG ! tag:stsci.edu:asdf/ +--- !core/asdf-1.1.0 +asdf_library: !core/software-1.0.0 {author: The ASDF Developers, homepage: 'http://github.com/asdf-format/asdf', + name: asdf, version: 3.1.1.dev2+g15e830d} +history: + extensions: + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/astronomy/gwcs/extensions/gwcs-1.2.0 + software: !core/software-1.0.0 {name: gwcs, version: 0.21.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/astronomy/coordinates/extensions/coordinates-1.0.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/core/extensions/core-1.5.0 + software: !core/software-1.0.0 {name: asdf, version: 3.1.1.dev2+g15e830d} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/transform/extensions/transform-1.5.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf_astropy._manifest.CompoundManifestExtension + extension_uri: asdf://astropy.org/core/extensions/core-1.5.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://stsci.edu/jwst_pipeline/extensions/jwst_transforms-1.0.0 + software: !core/software-1.0.0 {name: stdatamodels, version: 1.10.1} +wcs: ! + name: '' + pixel_shape: null + steps: + - ! + frame: ! + axes_names: [x_grism, y_grism] + axes_order: [0, 1] + axis_physical_types: ['custom:x_grism', 'custom:y_grism'] + name: grism_detector + unit: [!unit/unit-1.0.0 pixel, !unit/unit-1.0.0 pixel] + transform: !transform/compose-1.2.0 + bounding_box: !transform/property/bounding_box-1.0.0 + ignore: [] + intervals: + x0: [-0.5, 319.5] + x1: [-0.5, 340.5] + order: C + forward: + - !transform/compose-1.2.0 + forward: + - !transform/remap_axes-1.3.0 + inputs: [x0, x1] + mapping: [0, 1, 0, 0, 0] + outputs: [x0, x1, x2, x3, x4] + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: 122.0 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: 1031.0 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 482.26565028841355 + outputs: [y] + value: 482.26565028841355 + inputs: [x0, x1, x] + outputs: [y0, y1, y] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 1205.025009833007 + outputs: [y] + value: 1205.025009833007 + inputs: [x00, x10, x0, x1] + outputs: [y00, y10, y0, y1] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 1.0 + outputs: [y] + value: 1.0 + inputs: [x00, x10, x0, x1, x] + outputs: [y00, y10, y0, y1, y] + inputs: [x0, x1] + outputs: [y00, y10, y0, y1, y] + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - ! + inputs: [x, y, x0, y0, order] + inverse: ! + inputs: [x, y, wavelength, order] + lmodels: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 35 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 36 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 37 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 38 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 39 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + model_type: NIRISSBackwardGrismDispersion + name: niriss_backward_grism_dispersion + orders: [1, 2, 3, -1, 0] + outputs: [x, y, x0, y0, order] + theta: 0.006000000000653927 + xmodels: + - - &id001 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 0 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id002 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 1 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id003 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 2 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id004 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 3 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id005 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 4 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id006 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 5 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id007 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 6 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id008 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 7 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id009 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 8 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id010 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 9 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id011 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 10 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id012 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 11 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id013 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 12 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id014 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 13 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id015 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 14 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + ymodels: + - - &id016 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 15 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id017 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 16 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id018 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 17 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id019 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 18 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id020 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 19 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id021 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 20 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id022 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 21 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id023 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 22 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id024 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 23 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id025 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 24 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id026 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 25 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id027 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 26 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id028 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 27 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id029 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 28 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id030 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 29 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + lmodels: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 30 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 31 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 32 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 33 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 34 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + model_type: NIRISSForwardRowGrismDispersion + name: niriss_forward_row_grism_dispersion + orders: [1, 2, 3, -1, 0] + outputs: [x, y, wavelength, order] + theta: -0.006000000000653927 + xmodels: + - - *id001 + - *id002 + - *id003 + - - *id004 + - *id005 + - *id006 + - - *id007 + - *id008 + - *id009 + - - *id010 + - *id011 + - *id012 + - - *id013 + - *id014 + - *id015 + ymodels: + - - *id016 + - *id017 + - *id018 + - - *id019 + - *id020 + - *id021 + - - *id022 + - *id023 + - *id024 + - - *id025 + - *id026 + - *id027 + - - *id028 + - *id029 + - *id030 + - !transform/remap_axes-1.3.0 + inputs: [x0, x1, x2, x3] + mapping: [0, 1, 2, 3] + outputs: [x0, x1, x2, x3] + inputs: [x, y, x0, y0, order] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + - !transform/multiply-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + name: velocity_correction + outputs: [y] + value: 0.99999381609348 + inputs: [x0] + inverse: !transform/divide-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + name: inv_vel_correction + outputs: [y] + value: 0.99999381609348 + inputs: [x0] + outputs: [x0] + outputs: [x0] + inputs: [x00, x10, x01] + outputs: [x00, x10, x01] + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + inputs: [x00, x10, x01, x0] + outputs: [x00, x10, x01, x0] + inputs: [x, y, x0, y0, order] + outputs: [x00, x10, x01, x0] + inputs: [x0, x1] + outputs: [x00, x10, x01, x0] + - ! + frame: ! + frames: + - ! + axes_names: [x, y] + axes_order: [0, 1] + axis_physical_types: ['custom:x', 'custom:y'] + name: detectorspatial + unit: [!unit/unit-1.0.0 pixel, !unit/unit-1.0.0 pixel] + - &id033 ! + axes_names: [wavelength] + axes_order: [2] + axis_physical_types: [em.wl] + name: spectral + unit: [!unit/unit-1.0.0 um] + name: detector + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + bounding_box: !transform/property/bounding_box-1.0.0 + ignore: [] + intervals: + x0: [-0.5, 2047.5] + x1: [-0.5, 2047.5] + order: C + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: 2.119 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -1.0476 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: -1023.5 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -1023.5 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - &id031 !transform/remap_axes-1.3.0 + inputs: [x0, x1] + inverse: !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + mapping: [0, 1, 0, 1] + outputs: [x0, x1, x2, x3] + inputs: [x0, x1] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 40 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 41 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + inverse: !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 42 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 43 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + outputs: [z0, z1] + outputs: [z0, z1] + inputs: [x0, x1] + outputs: [z0, z1] + - &id032 !transform/identity-1.2.0 + inputs: [x0, x1] + inverse: !transform/remap_axes-1.3.0 + inputs: [x0, x1] + mapping: [0, 1, 0, 1] + outputs: [x0, x1, x2, x3] + n_dims: 2 + outputs: [x0, x1] + inputs: [x0, x1] + outputs: [x0, x1] + - *id031 + inputs: [x0, x1] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 44 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 45 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + inverse: !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 46 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 47 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + outputs: [z0, z1] + outputs: [z0, z1] + inputs: [x0, x1] + outputs: [z0, z1] + - *id032 + inputs: [x0, x1] + outputs: [x0, x1] + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: -291.141 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -698.015 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [y0, y1, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [v2, v3] + axes_order: [0, 1] + axis_physical_types: ['custom:v2', 'custom:v3'] + name: v2v3spatial + unit: [!unit/unit-1.0.0 arcsec, !unit/unit-1.0.0 arcsec] + - *id033 + name: v2v3 + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/scale-1.2.0 + factor: 0.9999939001894596 + inputs: [x] + name: dva_scale_v2 + outputs: [y] + - !transform/scale-1.2.0 + factor: 0.9999939001894596 + inputs: [x] + name: dva_scale_v3 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + name: dva_v2_shift + offset: -0.0017759049405570732 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + name: dva_v3_shift + offset: -0.004257759254392014 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + name: DVA_Correction + outputs: [y0, y1] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [y0, y1, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [v2, v3] + axes_order: [0, 1] + axis_physical_types: ['custom:v2', 'custom:v3'] + name: v2v3vacorrspatial + unit: [!unit/unit-1.0.0 arcsec, !unit/unit-1.0.0 arcsec] + - *id033 + name: v2v3vacorr + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/scale-1.2.0 + factor: 0.0002777777777777778 + inputs: [x] + outputs: [y] + - !transform/scale-1.2.0 + factor: 0.0002777777777777778 + inputs: [x] + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - ! + inputs: [lon, lat] + outputs: [x, y, z] + transform_type: spherical_to_cartesian + wrap_lon_at: 180 + inputs: [x0, x1] + outputs: [x, y, z] + - !transform/rotate_sequence_3d-1.0.0 + angles: [-0.0808725, 0.19389305555555555, 196.1037680531535, 65.83768326569226, + -260.82488865179715] + axes_order: zyxyz + inputs: [x, y, z] + outputs: [x, y, z] + rotation_type: cartesian + inputs: [x0, x1] + outputs: [x, y, z] + - ! + inputs: [x, y, z] + outputs: [lon, lat] + transform_type: cartesian_to_spherical + wrap_lon_at: 360 + inputs: [x0, x1] + name: v23tosky + outputs: [lon, lat] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [lon, lat, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [lon, lat] + axes_order: [0, 1] + axis_physical_types: [pos.eq.ra, pos.eq.dec] + name: sky + reference_frame: ! + frame_attributes: {} + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] + - *id033 + name: world + transform: null +... +ÓBLK0HHH°_ÏnHÈÂØä¢LdÐîcÙ=yŒO@Íîuµ‰Q?*œŸÎþŒš¾,U ÑÍpf¿=)Ðc{ïŒ>uF.ìZµ>ÓBLK0HHHï.ö+•`1Ë›XMVäñŸOjMtÀ«ž¿ø¦s¿áÿ+†i–§>ʳDÇ rr?ÒàÓb¾“¥>Jý…¹òžÓBLK0HHH—ìêZĬFžõõ$}O‘X +7ÏÇëÒ?=î AïÎH?V§®Sto¾uf%'¿·\¿tÄÄqñó¡¾Ü™Ë˜çZ²>ÓBLK0HHHx¸"ùü~q'oàFûÑ 'kÔC4ÚVÀ<±Ì÷Ü-¿¥æ“ª¥;©¾ÞQTŸëc?MŽb”¶‡>˜êþ¼· ¿¾ÓBLK0HHHù8Þ0à$zÝ—þYxJ¹ªSDúíë@P„ÀsëjxÄ ¿Õž°Úl¦>EapýV}¿+ŽÍÿÀ§Ã¾È‰¸ + Û>ÓBLK0HHHÒ9æù×8óVØI€ãi˜kÑ´-þ¿ÇSkn@Üq?ç,°zþDîøL.z?­GˆÃå:¬¾åùø# ϾÓBLK0HHH€²`Ôl^vÍCq`ròó‘ˆc]ÜFçjÀˆ”°ã‘?¡‡ça;ݾoÀBA„º¿ž2溾ԣ¼]Žš?ÓBLK0HHHúЂ£ßï‡#ô`¨xœ#ÑóD‹lçû÷À²á]쳿¬_³2ñø>žY·æ3ÅÒ?óÒt¹Ué>Ò¼¬Øs"¿ÓBLK0HHHžôî°ðÃ÷y­Ë³ßy¸éÏ~¤ZQ@—×øª?h¼$º]¢ð¾~ìäÂ7Ê¿‰cħÖë¾[\¿¶€D?ÓBLK0HHHO8,òµnÄHðêï& ŠO3f÷äa¡—w@J_ý¡VæO?)ÃMÿ¦ª>¨ÃQŒÓs?Q"‰È—s—¾q²P˜T3½¾ÓBLK0HHHŒkh.¿Ž )½ ]AóÞÝ$•Út@P0Jm5-‘?i¡6m½Û¾}‰Âx9¿ëµÿ² V®> +$~±×šë>ÓBLK0HHH/¢l{ýk²÷5Ò•ûy4¸­-ÜÀc»V3cˆ¿°˜ñ31UÔ>&JJ“µ–?%¿"Q°>¡ã¡2°\ç¾ÓBLK0HHHp¤zó©¡kÏìðc–PtY\Âõ¸ŠÑÀR·éS·p?jXNu ë>žuðÀ¾¥¿XY‘$ù¾!—Ѿ·I?ÓBLK0HHH ¹êQ4ðNf½õùR¸…ëÊyÀWÑÝ}§,?T !° +… ?ÛÆ]1á­¿©)¬”­Ÿ¿ìý¦­š²&?ÓBLK0HHH¾…R¥„q3:IWw7F5š™™™ýKò@CÁR~ÎŒ?®¶6ep ? QðÚÆ®¿º'µ¿LǦÙ_'?ÓBLK0HHHÐMyŒ0g”¶? ˜nÉSX© ¢jø¿yøõo¿ŸÐ¾Jø„ãL>eÿY(Àrå>çXèÞ ¾Ç7®úX4¾ÓBLK0HHH¾,½½dƒë ý3.DMÒP¦ï5Çó?–@”»œÔò>¹Îhy£&¾ìTU…±ñ¾€£èdìÊ#¾±äGAöD>ÓBLK0HHH^›l€ÆáH”.‰’í\’2Ëø÷¿MßWy +²É¾ÊÃ÷EYîð=Óqõ“³´Û>H9•‹eÈ >’ÌÓh£1¾ÓBLK0HHHŽê$ÄÑARÓ# kà«ìQ¸…ëÁ?„ÈÆvÊļkÕ¶ÿË­<ÔxS-³Á¼~RÃÞò <# eé6<ÓBLK0HHHóê6°Hî8ÛiX Gò¾ú’•RºÈy½‘åYYÙ<á“ÆÞQ€¼df•BšÕ<¢07~¤!¼p—änf¼ÓBLK0HHH0¡”ýXãx£ç‡©ll?éî:o=•ƒKx©°Î¼¹cr–© +<·%÷Ž1ʼ€IøÓ]<€Ò†²<ÓBLK0HHHxÁ§¬å¡?§¯ é¦(Üóüi#@Nç<Ù^¼?8 øëòâ]¾Ðü¸€-¿éËdÓE¾‡ÐtøKP|>ÓBLK0HHHeN¨SZÒ˜äGG¤-To+½6{ÀèÅÄ 2¿™‚ä|ilw>GâÉ°E?”„ùÔ?Ói>' + (XÍ•¾ÓBLK0HHHôÕÚ-—­&?ִŠͅÒrû@ÍÑ?ŸOæÜ;œ(?thÄoŠ“o¾®6 :o>¿Š¢%žh¾BÅòƒ>ÓBLK0HHHü8 ¤=¡wÀ/¤ÎD‚Ë"§¥gz‰±,ÀavN?ä¾H¥A@þ@¾ßµ‚•%¿¬çE8”V$>×G½fqúQ>ÓBLK0HHHù‡xÓòØð¯üÁ—yÈæÛàDôk« +À‚8Ú&¿ajOnq>Œi‰I2?×£摯9¾™–î³²l¾ÓBLK0HHH÷+OKV/ë£÷_„ñs“%߶~×¥?ß. +êóK? S|i¾¢Y$v,¿Ê3׬NI¾˜^ÝÌQ“}>ÓBLK0HHH‹µbX"¿ÄÁæì\ŽM #R¸…ëQÀ9l¤Cä½s-Ø£4=b!ˆ|jFÚ=P.|Åi§¼D½ðP¨Æ$½ÓBLK0HHHX…* í«0§ådòì`¸ÓPf˜7%¾Ö¹d±¾hòë§îU=JÆCëÔú=ù ¸Û-´Ç¼ó‡ðìµ6E½ÓBLK0HHH•G¾þRÁÐÇnì¾Øä’Öë£X6p%¾%ÐÕ¯š!¾þã¼H$…U=5ø·…Édû=zûâñ/ÿǼʟR¾©E½ÓBLK0_“Ä—V˱†úþ“+L+è?ÍÌÌÌÌÌø?ÓBLK0_“Ä—V˱†úþ“+L+è?ÍÌÌÌÌÌø?ÓBLK0_“Ä—V˱†úþ“+L+è?ÍÌÌÌÌÌø?ÓBLK0_“Ä—V˱†úþ“+L+è?ÍÌÌÌÌÌø?ÓBLK0_“Ä—V˱†úþ“+L+è?ÍÌÌÌÌÌø?ÓBLK0 + ãºÜ"ZÞUÊ`àZéc÷Þ{ï½÷Þ¿¥”RJ)¥ä?ÓBLK0 + ãºÜ"ZÞUÊ`àZéc÷Þ{ï½÷Þ¿¥”RJ)¥ä?ÓBLK0 + ãºÜ"ZÞUÊ`àZéc÷Þ{ï½÷Þ¿¥”RJ)¥ä?ÓBLK0 + ãºÜ"ZÞUÊ`àZéc÷Þ{ï½÷Þ¿¥”RJ)¥ä?ÓBLK0 + ãºÜ"ZÞUÊ`àZéc÷Þ{ï½÷Þ¿¥”RJ)¥ä?ÓBLK0   Œn¥Ž ¶^ä·ß‰»îHfÇqDZ;Œ…¼¦ß{ç½úÍ@“•½Æ¥z¶ÐÐ<êžóp‘/<¸rP$ะ?«©°K>&jIáN ½³09ÊFu½?Ø«Õ0o*\›Tvܽú.„‡Ûs&½ Ë^Öq‡Šä“2+±½S‹®½A§À)³B<ÝÏ1º[p=jòûËТé¼%ÈWújE`<ؼãäúÙļ5ù?U¨‚¼õk”*ÿ=¼ÓBLK0   ÏÒ ’Ô1|þþk1Lœ†ÏŸl‚>.%jGýH‘>ßX +ú¾YÃ×f¦½¹‘T-ž.@êáIBø¿êCŽ5«R>Ï–‹~ʽ\,ÓñöW <->¡°iÎ^ ä½v\sdd>^K©Ú_ +¾ ®Ã d"ì=ùñyNá&>Òç¿zj|Ƚv:£‘²=ÓBLK0   \ÓfïÆp†²¹MOck+„o(>U½c.@ÒÀªRƒ0¿D^5ÕÖ>@SÁü»kb>n¼‹=™J¾ç9$¡úÒ±?¿ÌÓ¼?¼âÿÐf¢>&wRÄ°C¾d¬ÿvb'ž=ºAüº]L*¿Ì6UŒ®«>èÄ›‘™X>Pö7]l®½½=iùøÄj¾q.è!®×&>TÔ„ÍDؽN%¥•Ôÿ=oÃö˜Ÿ q=eÅz[¶=ÓBLK0 ÆïÚê…ºb"I×,%„Ü‚˜xÒ„?‘ê`›ÿï¿ÓBLK0 ¯<Ê6YÖ_z›Õ»)J±&‘ê`›ÿï?Ü‚˜xÒ„?ÓBLK0 ÆïÚê…ºb"I×,%„Ü‚˜xÒ„?‘ê`›ÿï¿ÓBLK0 ¯<Ê6YÖ_z›Õ»)J±&‘ê`›ÿï?Ü‚˜xÒ„?#ASDF BLOCK INDEX +%YAML 1.1 +--- +- 38436 +- 38562 +- 38688 +- 38814 +- 38940 +- 39066 +- 39192 +- 39318 +- 39444 +- 39570 +- 39696 +- 39822 +- 39948 +- 40074 +- 40200 +- 40326 +- 40452 +- 40578 +- 40704 +- 40830 +- 40956 +- 41082 +- 41208 +- 41334 +- 41460 +- 41586 +- 41712 +- 41838 +- 41964 +- 42090 +- 42216 +- 42286 +- 42356 +- 42426 +- 42496 +- 42566 +- 42636 +- 42706 +- 42776 +- 42846 +- 42916 +- 43258 +- 43600 +- 43942 +- 44284 +- 44370 +- 44456 +- 44542 +... diff --git a/jwst/wfss_contam/tests/data/segmentation_wcs.asdf b/jwst/wfss_contam/tests/data/segmentation_wcs.asdf new file mode 100644 index 0000000000000000000000000000000000000000..48c6218c8c36a5253a066cc56ad7576001009a09 GIT binary patch literal 4767 zcmdrQOOM+`xJQ%&LgK;&G*wmQK;p!XU$H9CcDIGn?Sg29Kr3X;+LOe>vCYhQo2crI zJ3j(d{EVLX1BD;J0f`F_i5oY*&(E=&P%T{{gb(p}=6k;<*ADI<-6py=)uw}Mxc2Wq zI5;^v`1lAmE!e#N*}=VISnGr94-VcVSBa!cPe>6&8Vhreif}O*jK^bgHAp#|;P)7y z0TY*g9G!8>w_bHD(uI6TIRg;jNFLE#u2Sxi2P;MjNsib#i_?s8F(RvU%`&=V9vQBr z%)H5D8Oc?CrUmJ`WVA^6n#y8591RFb=o+Ar@raxQjVMh#G6QLJEoA4$+e_VI?ra)9 z9;_lEQ;-Xg*haDhK?Q!*iZr|ann@Z`NlTdlL_zU`mzNn)|HVnqyp1Rm?I)GLK(M zW33TQ(fMM!LbQ}l)KK3-$|PNHC(DO{m^9=?@eP_A7t$PwbTl>~?EbSV2WiT~D4~*d zG*`Bj2>jo<{?bBSarn;6f0-1&^njO~CIS;cjrC^pl2aYB0V>6fZUHY64Nl7BdI#!3 zSv7`3SZ2|N#lBe449e=Msay7l7`knJ4rfBLtge{5XyQ2#g^WX@ge-VboOkjs&Wx~( z?-HC4x?#dcVLoQ`4H+EnH>u+m#d$x6sQ&?3(yM=TsF=0;ptI z%{od$9CWdIcw!YT(Auq;kaUz}xh%q<6{ZfG%FK`+m$=8Q=Av??gcTOPsG#v_dSqUXdO7iwd-#hJ8D_Xg7I8|YZ7K%E^RumFRlLTLlrig_*G0{Milzot9adT3^AuRX zTnO_VO4}jOdO@vB$e55K3*rLxA{weOjASlNLsh+AlCFXD7G{5K_36CVZH zD`Bf|FSUDEo%F`D7O|I53acJc#~S}M!=cx#(tIH6i5r@WE+_Rtz8{&&?Cr$QB ztePiVMI$0taLrQv-;umXMK;3;pZWNLjyv!0%LIVvE?%tBclA@~=@DqMt~s+D%e8gG zovX34*JiDk_Y1cN`NgO$Vkap79ua)>vZG()EWa26%O0?E1dP*UM4+tJp19x`2=kZ5 zzSKA!k!l&Tpr%z8nujTm!^&HnEf{A>z)HEGkQbUNw{!UKqqy(GaMxbk3MIrO3OG*G z&G3ufFOGaFp_iQHaLMz?qsJAqee-6>mOG|R*e&I5#x0K|c^nTkO&ffF^Z4$S>6Pw7 zs?X;A$t&Ny`r}*qA6HKG-#@zki*@|u>o>pq_P2L$-1*_xuTGymt>4xD)16dbtEe{b z4}N~#{^{ Date: Wed, 10 Apr 2024 15:26:33 -0400 Subject: [PATCH 06/11] more refactoring of disperse, and unit tests --- jwst/wfss_contam/disperse.py | 41 +++++++++----- jwst/wfss_contam/observations.py | 23 ++++---- jwst/wfss_contam/tests/test_disperse.py | 59 ++++++++------------- jwst/wfss_contam/tests/test_observations.py | 43 +++++++++++++++ 4 files changed, 106 insertions(+), 60 deletions(-) diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index 8a4ecde4a0..d76b5a1db4 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -44,6 +44,33 @@ def flux(x): return flux +def determine_wl_spacing(dw, lams, oversample_factor): + ''' + Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, + whichever is smaller, divided by oversampling requested + + Parameters + ---------- + dw : float + The natural wavelength scale of the grism image + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + oversample_factor : int + The amount of oversampling + + Returns + ------- + dlam : float + The wavelength spacing to use for the dispersed pixels + ''' + # + if len(lams) > 1: + input_dlam = np.median(lams[1:] - lams[:-1]) + if input_dlam < dw: + return input_dlam / oversample_factor + return dw / oversample_factor + def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, @@ -137,19 +164,9 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, dxw = xwmax - xwmin dyw = ywmax - ywmin - # Compute the delta-wave per pixel - dw = np.abs((wmax - wmin) / (dyw - dxw)) - - # Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, - # whichever is smaller, divided by oversampling requested - if len(lams) > 1: - input_dlam = np.median(lams[1:] - lams[:-1]) - if input_dlam < dw: - dlam = input_dlam / oversample_factor - else: - dlam = dw / oversample_factor - # Create list of wavelengths on which to compute dispersed pixels + dw = np.abs((wmax - wmin) / (dyw - dxw)) + dlam = determine_wl_spacing(dw, lams, oversample_factor) lambdas = np.arange(wmin, wmax + dlam, dlam) n_lam = len(lambdas) diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 8e75648e35..fe7606104c 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -247,16 +247,16 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): # Loop over all source ID's from segmentation map pool_args = [] for i in range(len(self.IDs)): - if self.cache: - self.cached_object[i] = {} - self.cached_object[i]['x'] = [] - self.cached_object[i]['y'] = [] - self.cached_object[i]['f'] = [] - self.cached_object[i]['w'] = [] - self.cached_object[i]['minx'] = [] - self.cached_object[i]['maxx'] = [] - self.cached_object[i]['miny'] = [] - self.cached_object[i]['maxy'] = [] + #if self.cache: + # self.cached_object[i] = {} + # self.cached_object[i]['x'] = [] + # self.cached_object[i]['y'] = [] + # self.cached_object[i]['f'] = [] + # self.cached_object[i]['w'] = [] + # self.cached_object[i]['minx'] = [] + # self.cached_object[i]['maxx'] = [] + # self.cached_object[i]['miny'] = [] + # self.cached_object[i]['maxy'] = [] disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp, self.IDs[i], self.xs[i], self.ys[i], @@ -429,6 +429,7 @@ def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, flu return (this_object, thisobj_bounds, sid, order) return (this_object, None, sid, order) + @staticmethod def construct_slitmodel_for_chunk(chunk_data, bounds, sid, order): ''' @@ -450,8 +451,8 @@ def construct_slitmodel_for_chunk(chunk_data, bounds, sid, order): ''' if bounds is None: return None + [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] = bounds - slit = datamodels.SlitModel() slit.source_id = sid slit.name = f"source_{sid}" diff --git a/jwst/wfss_contam/tests/test_disperse.py b/jwst/wfss_contam/tests/test_disperse.py index 9e1a98976f..e88c1c9adb 100644 --- a/jwst/wfss_contam/tests/test_disperse.py +++ b/jwst/wfss_contam/tests/test_disperse.py @@ -1,46 +1,31 @@ import pytest import numpy as np -from jwst.wfss_contam.disperse import dispersed_pixel -from jwst.wfss_contam.tests.test_observations import grism_wcs, segmentation_map, direct_image +from jwst.wfss_contam.disperse import interpolate_fluxes, determine_wl_spacing +''' +Note that main disperse.py call is tested in test_observations.py because +it requires all the fixtures defined there. +''' -def test_oversample_same_result(grism_wcs, segmentation_map): - ''' - Coverage for bug where wavelength oversampling led to double-counted fluxes +@pytest.mark.parametrize("lams, flxs, extrapolate_sed, expected_outside_bounds", + [([1, 3], [1, 3], False, 0), + ([2], [2], False, 2), + ([1, 3], [1, 3], True, 4)]) +def test_interpolate_fluxes(lams, flxs, extrapolate_sed, expected_outside_bounds): - note: segmentation_map fixture needs to be able to find module-scoped direct_image - fixture, so it must be imported here - ''' + flux_interpf = interpolate_fluxes(lams, flxs, extrapolate_sed) + assert flux_interpf(2.0) == 2.0 + assert flux_interpf(4.0) == expected_outside_bounds - # manual input of input params set the same as test_observations.py - x0 = 300.5 - y0 = 300.5 - order = 1 - width = 1.0 - height = 1.0 - lams = [2.0] - flxs = [1.0] - ID = 0 - naxis = (300, 500) - sens_waves = np.linspace(1.708, 2.28, 100) - wmin, wmax = np.min(sens_waves), np.max(sens_waves) - sens_resp = np.ones(100) - seg_wcs = segmentation_map.meta.wcs - 0, (300, 500), 2, False, - xoffset = 2200 - yoffset = 1000 - - xs, ys, areas, lams_out, counts_1, ID = dispersed_pixel( - x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, - oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, - yoffset=yoffset) +@pytest.mark.parametrize("lams, expected_dw", + [([1, 1.2, 1.4], 0.05), + ([1, 1.02, 1.04], 0.01) + ]) +def test_determine_wl_spacing(lams, expected_dw): - xs, ys, areas, lams_out, counts_3, ID = dispersed_pixel( - x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, - oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, - yoffset=yoffset) + dw = 0.1 + oversample_factor = 2 + dw_out = determine_wl_spacing(dw, np.array(lams), oversample_factor) - assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) \ No newline at end of file + assert np.isclose(dw_out, expected_dw, atol=1e-8) diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index 496089b94e..70d51f7c5d 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -10,6 +10,7 @@ from photutils.segmentation import SourceFinder from jwst.wfss_contam.observations import background_subtract, _select_ids, Observation +from jwst.wfss_contam.disperse import dispersed_pixel from jwst.wfss_contam.tests import data from jwst.datamodels import SegmentationMapModel, ImageModel @@ -189,3 +190,45 @@ def test_disperse_chunk_null(observation): assert chunk_bounds is None assert np.all(chunk == 0) + + +def test_disperse_oversample_same_result(grism_wcs, segmentation_map): + ''' + Coverage for bug where wavelength oversampling led to double-counted fluxes + + note: segmentation_map fixture needs to be able to find module-scoped direct_image + fixture, so it must be imported here + ''' + + # manual input of input params set the same as test_observations.py + x0 = 300.5 + y0 = 300.5 + order = 1 + width = 1.0 + height = 1.0 + lams = [2.0] + flxs = [1.0] + ID = 0 + naxis = (300, 500) + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + seg_wcs = segmentation_map.meta.wcs + 0, (300, 500), 2, False, + xoffset = 2200 + yoffset = 1000 + + + xs, ys, areas, lams_out, counts_1, ID = dispersed_pixel( + x0, y0, width, height, lams, flxs, order, wmin, wmax, + sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, + yoffset=yoffset) + + xs, ys, areas, lams_out, counts_3, ID = dispersed_pixel( + x0, y0, width, height, lams, flxs, order, wmin, wmax, + sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, + yoffset=yoffset) + + assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) \ No newline at end of file From 05b9b5d88d5e956fa98e46096085f8d03b0c5115 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 12 Apr 2024 11:28:22 -0400 Subject: [PATCH 07/11] moved multiprocessing back to pixel level, added more unit tests --- jwst/wfss_contam/observations.py | 186 +++++++++----------- jwst/wfss_contam/tests/test_observations.py | 23 +-- jwst/wfss_contam/tests/test_wfss_contam.py | 53 +++++- jwst/wfss_contam/wfss_contam.py | 185 ++++++++++++------- 4 files changed, 263 insertions(+), 184 deletions(-) diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index fe7606104c..e1e192cbbc 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -9,7 +9,6 @@ from .disperse import dispersed_pixel import logging -import warnings from photutils.background import Background2D, MedianBackground from astropy.stats import SigmaClip @@ -247,36 +246,29 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): # Loop over all source ID's from segmentation map pool_args = [] for i in range(len(self.IDs)): - #if self.cache: - # self.cached_object[i] = {} - # self.cached_object[i]['x'] = [] - # self.cached_object[i]['y'] = [] - # self.cached_object[i]['f'] = [] - # self.cached_object[i]['w'] = [] - # self.cached_object[i]['minx'] = [] - # self.cached_object[i]['maxx'] = [] - # self.cached_object[i]['miny'] = [] - # self.cached_object[i]['maxy'] = [] - - disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp, - self.IDs[i], self.xs[i], self.ys[i], - self.fluxes, - self.seg_wcs, self.grism_wcs, self.dims, - self.extrapolate_sed, self.xoffset, self.yoffset] + + if self.cache: + self.cached_object[i] = {} + self.cached_object[i]['x'] = [] + self.cached_object[i]['y'] = [] + self.cached_object[i]['f'] = [] + self.cached_object[i]['w'] = [] + self.cached_object[i]['minx'] = [] + self.cached_object[i]['maxx'] = [] + self.cached_object[i]['miny'] = [] + self.cached_object[i]['maxy'] = [] + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp,] pool_args.append(disperse_chunk_args) - # call disperse_chunk with optional multiprocessing t0 = time.time() if self.max_cpu > 1: + # put this log message here to avoid printing it for every chunk log.info(f"Using multiprocessing with {self.max_cpu} cores to compute dispersion") - ctx = mp.get_context("forkserver") - with ctx.Pool(self.max_cpu) as mypool: - disperse_chunk_output = mypool.starmap(self.disperse_chunk, pool_args) - else: - disperse_chunk_output = [] - for i in range(len(self.IDs)): - disperse_chunk_output.append(self.disperse_chunk(*pool_args[i])) + disperse_chunk_output = [] + for i in range(len(self.IDs)): + disperse_chunk_output.append(self.disperse_chunk(*pool_args[i])) t1 = time.time() log.info(f"Wall clock time for disperse_chunk order {order}: {(t1-t0):.1f} sec") @@ -290,19 +282,17 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.simul_slits_order.append(this_order) self.simul_slits_sid.append(this_sid) - @staticmethod - def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, fluxes_dict, seg_wcs, grism_wcs, dims, extrapolate_sed, xoffset, yoffset): + + def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): """ Method that computes dispersion for a single source. To be called after create_pixel_list(). - Static method to enable parallelization - Parameters ---------- c : int - Chunk (source) number to process. used to index the fluxes dict + Chunk (source) number to process order : int - Spectral order to process + Spectral order number to process wmin : float Minimum wavelength for dispersed spectra wmax : float @@ -310,80 +300,72 @@ def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, flu sens_waves : float array Wavelength array from photom reference file sens_resp : float array - Response (flux calibration) array from photom reference file - sid : int - Source ID - xs : np.ndarray - X-coordinates of the the central pixel of the group of pixels - surrounding the direct image pixel index - ys : np.ndarray - Y-coordinates of the the central pixel of the group of pixels - surrounding the direct image pixel index - fluxes_dict : dict - Dictionary of fluxes for each direct image. - fluxes_dict{"lams"} is the array of wavelengths previously stored in flux list - and correspond to the central wavelengths of the filters used in - the input direct image(s). For the simple case of 1 combined direct image, - this contains a single value (e.g. 4.44 for F444W). - fluxes_dict{"fluxes"} is the array of pixel values from the direct image(s). - For the simple case of 1 combined direct image, this contains a - a single value (just like "lams"). - seg_wcs : gwcs object - WCS object from segmentation map - grism_wcs : gwcs object - WCS object from grism image - dims : tuple - Dimensions of the grism image - extrapolate_sed : bool - Flag indicating whether to extrapolate wavelength range of SED - xoffset : int - Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to full frame) - yoffset : int - Pixel offset to apply when computing the dispersion (accounts for offset from source cutout to full frame) + Response (flux calibration) array from photom reference file Returns ------- this_object : np.ndarray - Dispersed model of segmentation map source - bounds : list - The bounds of the object + 2D array of dispersed pixel values for the source + thisobj_bounds : list + [minx, maxx, miny, maxy] bounds of the object sid : int - The source ID + Source ID order : int - The spectral order number + Spectral order number """ - log.info(f"Dispersing source {sid}, order {order}") - # Loop over all pixels in list for object "c" - log.debug(f"source {sid} contains {len(xs)} pixels") - all_res = [] - for i in range(len(xs)): - # Here "i" indexes the pixel list for the segment - # being processed, as opposed to the ID number of the segment + sid = int(self.IDs[c]) + self.order = order + self.wmin = wmin + self.wmax = wmax + self.sens_waves = sens_waves + self.sens_resp = sens_resp + log.info(f"Dispersing source {sid}, order {self.order}") + pars = [] # initialize params for this object + # Loop over all pixels in list for object "c" + log.debug(f"source contains {len(self.xs[c])} pixels") + for i in range(len(self.xs[c])): + # Here "i" is just an index into the pixel list for the object + # being processed, as opposed to the ID number of the object itself + # xc, yc are the coordinates of the central pixel of the group + # of pixels surrounding the direct image pixel index width = 1.0 height = 1.0 - xc = xs[i] + 0.5 * width - yc = ys[i] + 0.5 * height - + xc = self.xs[c][i] + 0.5 * width + yc = self.ys[c][i] + 0.5 * height + # "lams" is the array of wavelengths previously stored in flux list + # and correspond to the central wavelengths of the filters used in + # the input direct image(s). For the simple case of 1 combined direct image, + # this contains a single value (e.g. 4.44 for F444W). + # "fluxes" is the array of pixel values from the direct image(s). + # For the simple case of 1 combined direct image, this contains a + # a single value (just like "lams"). fluxes, lams = map(np.array, zip(*[ - (fluxes_dict[lm][c][i], lm) for lm in sorted(fluxes_dict.keys()) - if fluxes_dict[lm][c][i] != 0 + (self.fluxes[lm][c][i], lm) for lm in sorted(self.fluxes.keys()) + if self.fluxes[lm][c][i] != 0 ])) - - pars_i = (xc, yc, width, height, lams, fluxes, order, - wmin, wmax, sens_waves, sens_resp, - seg_wcs, grism_wcs, i, dims[::-1], 2, - extrapolate_sed, xoffset, yoffset) - if i == 0: - print(pars_i) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in scalar divide") - warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice") - all_res.append(dispersed_pixel(*pars_i)) + pars_i = [xc, yc, width, height, lams, fluxes, self.order, + self.wmin, self.wmax, self.sens_waves, self.sens_resp, + self.seg_wcs, self.grism_wcs, i, self.dims[::-1], 2, + self.extrapolate_sed, self.xoffset, self.yoffset] + pars.append(pars_i) + #if i == 0: + # print([type(arg) for arg in pars_i]) #all these need to be pickle-able + + # pass parameters into dispersed_pixel, either using multiprocessing or not + time1 = time.time() + if self.max_cpu > 1: + ctx = mp.get_context("forkserver") + with ctx.Pool(self.max_cpu) as mypool: + all_res = mypool.starmap(dispersed_pixel, pars) + else: + all_res = [] + for i in range(len(pars)): + all_res.append(dispersed_pixel(*pars[i])) # Initialize blank image for this source - this_object = np.zeros(dims, float) + this_object = np.zeros(self.dims, float) nres = 0 bounds = [] for pp in all_res: @@ -403,21 +385,23 @@ def disperse_chunk(c, order, wmin, wmax, sens_waves, sens_resp, sid, xs, ys, flu maxy = int(max(y)) a = sparse.coo_matrix((f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1)).toarray() - bounds.append([minx, maxx, miny, maxy]) - + # Accumulate results into simulated images this_object[miny:maxy + 1, minx:maxx + 1] += a + bounds.append([minx, maxx, miny, maxy]) - #if self.cache: - # self.cached_object[c]['x'].append(x) - # self.cached_object[c]['y'].append(y) - # self.cached_object[c]['f'].append(f) - # self.cached_object[c]['w'].append(w) - # self.cached_object[c]['minx'].append(minx) - # self.cached_object[c]['maxx'].append(maxx) - # self.cached_object[c]['miny'].append(miny) - # self.cached_object[c]['maxy'].append(maxy) + if self.cache: + self.cached_object[c]['x'].append(x) + self.cached_object[c]['y'].append(y) + self.cached_object[c]['f'].append(f) + self.cached_object[c]['w'].append(w) + self.cached_object[c]['minx'].append(minx) + self.cached_object[c]['maxx'].append(maxx) + self.cached_object[c]['miny'].append(miny) + self.cached_object[c]['maxy'].append(maxy) + time2 = time.time() + log.debug(f"Elapsed time {time2-time1} sec") # figure out global bounds of object if len(bounds) > 0: bounds = np.array(bounds) diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index 70d51f7c5d..baf0e3d425 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -53,8 +53,8 @@ def segmentation_map(direct_image): # turn this into a jwst datamodel model = SegmentationMapModel(data=segm.data) - asdf_file = asdf.open(os.path.join(data_path, "segmentation_wcs.asdf")) - wcsobj = asdf_file.tree['wcs'] + with asdf.open(os.path.join(data_path, "segmentation_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] model.meta.wcs = wcsobj return model @@ -62,8 +62,8 @@ def segmentation_map(direct_image): @pytest.fixture(scope='module') def grism_wcs(): - asdf_file = asdf.open(os.path.join(data_path, "grism_wcs.asdf")) - wcsobj = asdf_file.tree['wcs'] + with asdf.open(os.path.join(data_path, "grism_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] return wcsobj @@ -119,8 +119,6 @@ def test_create_pixel_list(observation, segmentation_map): def test_disperse_chunk(observation): ''' - disperse_chunk is a static method so need to give it lots of observation attributes as input - Note: it's not obvious how to get a trivial flux example from first principles even setting all input fluxes in dict to 1, because transforms change pixel areas in nontrivial ways. seems a bad idea to write a test that @@ -142,12 +140,7 @@ def test_disperse_chunk(observation): # set all fluxes to unity to try to make a trivial example obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape) - disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp, - obs.IDs[i], obs.xs[i], obs.ys[i], - obs.fluxes, - obs.seg_wcs, obs.grism_wcs, obs.dims, - obs.extrapolate_sed, obs.xoffset, obs.yoffset] - + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) #trivial bookkeeping @@ -180,11 +173,7 @@ def test_disperse_chunk_null(observation): obs.xoffset = 2200 obs.yoffset = 1000 - disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp, - obs.IDs[i], obs.xs[i], obs.ys[i], - obs.fluxes, - obs.seg_wcs, obs.grism_wcs, obs.dims, - obs.extrapolate_sed, obs.xoffset, obs.yoffset] + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) diff --git a/jwst/wfss_contam/tests/test_wfss_contam.py b/jwst/wfss_contam/tests/test_wfss_contam.py index 966f54e924..656e78c451 100644 --- a/jwst/wfss_contam/tests/test_wfss_contam.py +++ b/jwst/wfss_contam/tests/test_wfss_contam.py @@ -1,5 +1,8 @@ import pytest -from jwst.wfss_contam.wfss_contam import _determine_multiprocessing_ncores +from jwst.wfss_contam.wfss_contam import _determine_multiprocessing_ncores, _cut_frame_to_match_slit, build_common_slit +from jwst.datamodels import SlitModel +import numpy as np + @pytest.mark.parametrize("max_cores, num_cores, expected", [("none", 4, 1), @@ -11,3 +14,51 @@ def test_determine_multiprocessing_ncores(max_cores, num_cores, expected): assert _determine_multiprocessing_ncores(max_cores, num_cores) == expected +@pytest.fixture(scope="module") +def contam(): + return np.ones((10, 10))*0.1 + +@pytest.fixture(scope="module") +def slit0(): + slit = SlitModel(data=np.ones((5, 3))) + slit.xstart = 2 + slit.ystart = 3 + slit.xsize = 3 + slit.ysize = 5 + return slit + + +@pytest.fixture(scope="module") +def slit1(): + slit = SlitModel(data=np.ones((4, 4))*0.5) + slit.xstart = 3 + slit.ystart = 2 + slit.xsize = 4 + slit.ysize = 4 + return slit + + +def test_cut_frame_to_match_slit(slit0, contam): + cut_contam = _cut_frame_to_match_slit(contam, slit0) + assert cut_contam.shape == (5, 3) + assert np.all(cut_contam == 0.1) + + +def test_build_common_slit(slit0, slit1): + slit0, slit1 = build_common_slit(slit0, slit1) + + # check indexing in metadata + assert slit0.xstart == slit1.xstart + assert slit0.ystart == slit1.ystart + assert slit0.xsize == slit1.xsize + assert slit0.ysize == slit1.ysize + assert slit0.data.shape == slit1.data.shape + + # check data overlap + assert np.count_nonzero(slit0.data) == 15 + assert np.count_nonzero(slit1.data) == 16 + assert np.count_nonzero(slit0.data * slit1.data) == 6 + + # check data values + assert np.all(slit0.data[1:6, 0:3] == 1) + assert np.all(slit1.data[0:4, 1:5] == 0.5) diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index b36a513835..6662d3d52d 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -4,6 +4,7 @@ from stdatamodels.jwst import datamodels from astropy.table import Table +import copy from .observations import Observation from .sens1d import get_photom_data @@ -39,13 +40,115 @@ def _determine_multiprocessing_ncores(max_cores, num_cores): elif max_cores == 'all': ncpus = num_cores else: - ncpus = 1 + raise ValueError(f"Invalid value for max_cores: {max_cores}") log.debug(f"Found {num_cores} cores; using {ncpus}") return ncpus -def contam_corr(input_model, waverange, photom, max_cores, brightest_n=None): +def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders): + """ + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Source slit model + simul_slit_sids : list + List of source IDs for simulated slits + simul_slit_orders : list + List of spectral orders for simulated slits + + Returns + ------- + good_idx : int + Index of the matching simulated slit in the list of simulated slits + """ + + # Retrieve simulated slit for this source only + sid = slit.source_id + order = slit.meta.wcsinfo.spectral_order + good = (simul_slit_sids == sid) * (simul_slit_orders == order) + if not any(good): + return -1 + return np.where(good)[0][0] + + +def _cut_frame_to_match_slit(contam, slit): + + """Cut out the contamination image to match the extent of the source slit. + + Parameters + ---------- + contam : 2D array + Contamination image for the full grism exposure + slit : `~jwst.datamodels.SlitModel` + Source slit model + + Returns + ------- + cutout : 2D array + Contamination image cutout that matches the extent of the source slit + + """ + x1 = slit.xstart + y1 = slit.ystart + cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] + + return cutout + + +def build_common_slit(slit0, slit1): + ''' + put data from the two slits into a common backplane + so outputs have the same dimensions + and alignment is based on slit.xstart, slit.ystart + + Parameters + ---------- + slit0 : SlitModel + First slit model + slit1 : SlitModel + Second slit model + + Returns + ------- + slit0 : SlitModel + First slit model with data updated to common backplane + slit1 : SlitModel + Second slit model with data updated to common backplane + ''' + + data0 = slit0.data + data1 = slit1.data + + shape = (max(data0.shape[0], data1.shape[0]), max(data0.shape[1], data1.shape[1])) + xmin = min(slit0.xstart, slit1.xstart) + ymin = min(slit0.ystart, slit1.ystart) + shape = max(slit0.xsize + slit0.xstart - xmin, + slit1.xsize + slit1.xstart - xmin), \ + max(slit0.ysize + slit0.ystart - ymin, + slit1.ysize + slit1.ystart - ymin) + x0 = slit0.xstart - xmin + y0 = slit0.ystart - ymin + x1 = slit1.xstart - xmin + y1 = slit1.ystart - ymin + + backplane0 = np.zeros(shape).T + backplane0[y0:y0+data0.shape[0], x0:x0+data0.shape[1]] = data0 + backplane1 = np.zeros(shape).T + backplane1[y1:y1+data1.shape[0], x1:x1+data1.shape[1]] = data1 + + slit0.data = backplane0 + slit1.data = backplane1 + for slit in [slit0, slit1]: + slit.xstart = xmin + slit.ystart = ymin + slit.xsize = shape[0] + slit.ysize = shape[1] + + return slit0, slit1 + + +def contam_corr(input_model, waverange, photom, max_cores="none", brightest_n=None): """ The main WFSS contamination correction function @@ -172,76 +275,28 @@ def contam_corr(input_model, waverange, photom, max_cores, brightest_n=None): slits = [] for slit in output_model.slits: - # Retrieve simulated slit for this source only - sid = slit.source_id - order = slit.meta.wcsinfo.spectral_order - good = (simul_slit_sids == sid) * (simul_slit_orders == order) - if not any(good): - log.warning(f"Source {sid} order {order} requested by input slit model \ - but not found in simulated slits") + good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders) + if good_idx == -1: + log.warning(f"Source {slit.source_id} order {order} requested by input slit model \ + but not found in simulated slits") continue - else: - print('Subtracting contamination for source', sid, 'order', order) - good_idx = np.where(good)[0][0] this_simul = obs.simul_slits.slits[good_idx] - # cut out this source's contamination from the full simulated image - fullframe_sim = np.zeros(obs.dims) - y0 = this_simul.ystart - x0 = this_simul.xstart - fullframe_sim[y0:y0 + this_simul.ysize, x0:x0 + this_simul.xsize] = this_simul.data - contam = simul_all - fullframe_sim - - # Create a cutout of the contam image that matches the extent - # of the source slit - x1 = slit.xstart - 1 - y1 = slit.ystart - 1 - cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] - new_slit = datamodels.SlitModel(data=cutout) - # TO DO: - # not sure if the slit metadata is getting transferred properly - copy_slit_info(slit, new_slit) - slits.append(new_slit) - - # Subtract the cutout from the source slit - slit.data -= cutout + # Subtract source slit to make contamination image + # Simulated slits are sometimes different in shape than input data slits by a few pixels + this_simul, slit = build_common_slit(this_simul, slit) + simul_all_cut = _cut_frame_to_match_slit(simul_all, slit) + contam_cut = simul_all_cut - this_simul.data + contam_slit = copy.copy(slit) + contam_slit.data = contam_cut + slits.append(contam_slit) + + # Subtract the contamination from the source slit + slit.data -= contam_cut # Save the contamination estimates for all slits contam_model.slits.extend(slits) - print('number of slits in contam model', len(contam_model.slits)) - print('number of slits in output model', len(output_model.slits)) - print('number of slits in simul model', len(obs.simul_slits.slits)) - - # at what point does the output model get updated with the contamination-corrected data? - # Set the step status to COMPLETE output_model.meta.cal_step.wfss_contam = 'COMPLETE' return output_model, simul_model, contam_model, obs.simul_slits - - -def copy_slit_info(input_slit, output_slit): - - """Copy meta info from one slit to another. - - Parameters - ---------- - input_slit : SlitModel - Input slit model from which slit-specific info will be copied - - output_slit : SlitModel - Output slit model to which slit-specific info will be copied - - """ - output_slit.name = input_slit.name - output_slit.xstart = input_slit.xstart - output_slit.ystart = input_slit.ystart - output_slit.xsize = input_slit.xsize - output_slit.ysize = input_slit.ysize - output_slit.source_id = input_slit.source_id - output_slit.source_type = input_slit.source_type - output_slit.source_xpos = input_slit.source_xpos - output_slit.source_ypos = input_slit.source_ypos - output_slit.meta.wcsinfo.spectral_order = input_slit.meta.wcsinfo.spectral_order - output_slit.meta.wcsinfo.dispersion_direction = input_slit.meta.wcsinfo.dispersion_direction - output_slit.meta.wcs = input_slit.meta.wcs From 18598272b6cb660bf6f1a83a13239da5738decec Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 12 Apr 2024 12:31:32 -0400 Subject: [PATCH 08/11] fix style --- jwst/wfss_contam/observations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 01a390af39..e1e192cbbc 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,7 +1,6 @@ import time -import multiprocessing import numpy as np -import multiprocessing +import multiprocessing as mp from scipy import sparse @@ -357,7 +356,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # pass parameters into dispersed_pixel, either using multiprocessing or not time1 = time.time() if self.max_cpu > 1: - ctx = multiprocessing.get_context("forkserver") + ctx = mp.get_context("forkserver") with ctx.Pool(self.max_cpu) as mypool: all_res = mypool.starmap(dispersed_pixel, pars) else: From 4b6f2a4c1e3856e64586fa99f14e669db380d833 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Wed, 17 Apr 2024 17:07:27 -0400 Subject: [PATCH 09/11] added type hints and more unit tests --- jwst/wfss_contam/disperse.py | 42 ++++-- jwst/wfss_contam/observations.py | 134 +++++++++++++------- jwst/wfss_contam/tests/test_disperse.py | 4 +- jwst/wfss_contam/tests/test_observations.py | 69 +++++++++- jwst/wfss_contam/tests/test_wfss_contam.py | 24 +++- jwst/wfss_contam/wfss_contam.py | 70 ++++++---- jwst/wfss_contam/wfss_contam_step.py | 4 +- 7 files changed, 262 insertions(+), 85 deletions(-) diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index d76b5a1db4..d8150ef864 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -1,4 +1,6 @@ import numpy as np +from typing import Callable, Sequence +from astropy.wcs import WCS from scipy.interpolate import interp1d import warnings @@ -7,7 +9,10 @@ from .sens1d import create_1d_sens -def interpolate_fluxes(lams, flxs, extrapolate_sed): +def flux_interpolator_injector(lams: np.ndarray, + flxs: np.ndarray, + extrapolate_sed: bool, + ) -> Callable[[float], float]: ''' Parameters ---------- @@ -44,7 +49,10 @@ def flux(x): return flux -def determine_wl_spacing(dw, lams, oversample_factor): +def determine_wl_spacing(dw: float, + lams: np.ndarray, + oversample_factor: int, + ) -> float: ''' Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, whichever is smaller, divided by oversampling requested @@ -72,10 +80,26 @@ def determine_wl_spacing(dw, lams, oversample_factor): return dw / oversample_factor -def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, - oversample_factor=2, extrapolate_sed=False, xoffset=0, - yoffset=0): +def dispersed_pixel(x0: np.ndarray, + y0: np.ndarray, + width: float, + height: float, + lams: np.ndarray, + flxs: np.ndarray, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp: np.ndarray, + seg_wcs: WCS, + grism_wcs: WCS, + ID: int, + naxis: Sequence[int], + oversample_factor: int = 2, + extrapolate_sed: bool = False, + xoffset: float = 0, + yoffset: float = 0, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: """ This function take a list of pixels and disperses them using the information contained in the grism image WCS object and returns a list of dispersed pixels and fluxes. @@ -150,7 +174,7 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, imgxy_to_grismxy = grism_wcs.get_transform('detector', 'grism_detector') # Set up function for retrieving flux values at each dispersed wavelength - flux = interpolate_fluxes(lams, flxs, extrapolate_sed) + flux_interpolator = flux_interpolator_injector(lams, flxs, extrapolate_sed) # Get x/y positions in the grism image corresponding to wmin and wmax: # Start with RA/Dec of the input pixel position in segmentation map, @@ -205,11 +229,11 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, # values are naturally in units of physical fluxes, so we divide out # the sensitivity (flux calibration) values to convert to units of # countrate (DN/s). - # flux(lams) is either single-valued (for a single direct image) + # flux_interpolator(lams) is either single-valued (for a single direct image) # or an array of the same length as lams (for multiple direct images in different filters) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") - counts = flux(lams) * areas / (sens * oversample_factor) + counts = flux_interpolator(lams) * areas / (sens * oversample_factor) counts[no_cal] = 0. # set to zero where no flux cal info available return xs, ys, areas, lams, counts, ID diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index e1e192cbbc..4923416b98 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -5,6 +5,8 @@ from scipy import sparse from stdatamodels.jwst import datamodels +from astropy.wcs import WCS +from typing import Sequence from .disperse import dispersed_pixel @@ -17,7 +19,12 @@ log.setLevel(logging.DEBUG) -def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclude_percentile=30.0): +def background_subtract(data: np.ndarray, + box_size: tuple = None, + filter_size: tuple = (3,3), + sigma: float = 3.0, + exclude_percentile: float = 30.0, + ) -> np.ndarray: """ Simple astropy background subtraction @@ -57,7 +64,7 @@ def background_subtract(data, box_size=None, filter_size=(3,3), sigma=3.0, exclu return data - bkg.background -def _select_ids(ID, all_IDs): +def _select_ids(ID: int, all_IDs: list[int]) -> list[int]: ''' Select the source IDs to be processed based on the input ID parameter. @@ -91,9 +98,19 @@ def _select_ids(ID, all_IDs): class Observation: """This class defines an actual observation. It is tied to a single grism image.""" - def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=None, - sed_file=None, extrapolate_sed=False, - boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1): + def __init__(self, + direct_images: list[str], + segmap_model: datamodels.SegmentationMapModel, + grism_wcs: WCS, + filter: str, + ID: int = None, + sed_file: str = None, + extrapolate_sed: bool = False, + boundaries: Sequence = [], + offsets: Sequence = [0, 0], + renormalize: bool = True, + max_cpu: int = 1, + ) -> None: """ Initialize all data and metadata for a given observation. Creates lists of @@ -217,7 +234,13 @@ def create_pixel_list(self): for i in range(len(self.IDs)): self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]]) - def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): + def disperse_all(self, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp:np.ndarray, + cache=False): """ Compute dispersed pixel values for all sources identified in the segmentation map. @@ -283,7 +306,14 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.simul_slits_sid.append(this_sid) - def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): + def disperse_chunk(self, + c: int, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp: np.ndarray, + ) -> tuple[np.ndarray, list, int, int]: """ Method that computes dispersion for a single source. To be called after create_pixel_list(). @@ -415,7 +445,11 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): @staticmethod - def construct_slitmodel_for_chunk(chunk_data, bounds, sid, order): + def construct_slitmodel_for_chunk(chunk_data: np.ndarray, + bounds: list, + sid: int, + order: int, + ) -> datamodels.SlitModel: ''' Parameters ---------- @@ -450,53 +484,67 @@ def construct_slitmodel_for_chunk(chunk_data, bounds, sid, order): return slit - def disperse_all_from_cache(self, trans=None): - if not self.cache: - return - self.simulated_image = np.zeros(self.dims, float) +# class ObservationFromCache: +# ''' +# this isn't how it should work. If we're going to use a cache, we should +# be checking if a pixel is in the cache before dispersing it. Then load it if it's there, +# otherwise calculate it. The functions below need to be refactored. +# ''' - for i in range(len(self.IDs)): - this_object = self.disperse_chunk_from_cache(i, trans=trans) +# def __init__(self, cache, dims): +# self.cache = cache +# self.dims = dims +# self.simulated_image = np.zeros(self.dims, float) +# self.cached_object = {} - return this_object +# def disperse_all_from_cache(self, trans=None): +# if not self.cache: +# return - def disperse_chunk_from_cache(self, c, trans=None): - """Method that handles the dispersion. To be called after create_pixel_list()""" +# self.simulated_image = np.zeros(self.dims, float) - if not self.cache: - return +# for i in range(len(self.IDs)): +# this_object = self.disperse_chunk_from_cache(i, trans=trans) - time1 = time.time() +# return this_object - # Initialize blank image for this object - this_object = np.zeros(self.dims, float) +# def disperse_chunk_from_cache(self, c, trans=None): +# """Method that handles the dispersion. To be called after create_pixel_list()""" - if trans is not None: - log.debug("Applying a transmission function...") +# if not self.cache: +# return - for i in range(len(self.cached_object[c]['x'])): - x = self.cached_object[c]['x'][i] - y = self.cached_object[c]['y'][i] - f = self.cached_object[c]['f'][i] * 1. - w = self.cached_object[c]['w'][i] +# time1 = time.time() - if trans is not None: - f *= trans(w) +# # Initialize blank image for this object +# this_object = np.zeros(self.dims, float) - minx = self.cached_object[c]['minx'][i] - maxx = self.cached_object[c]['maxx'][i] - miny = self.cached_object[c]['miny'][i] - maxy = self.cached_object[c]['maxy'][i] +# if trans is not None: +# log.debug("Applying a transmission function...") - a = sparse.coo_matrix((f, (y - miny, x - minx)), - shape=(maxy - miny + 1, maxx - minx + 1)).toarray() +# for i in range(len(self.cached_object[c]['x'])): +# x = self.cached_object[c]['x'][i] +# y = self.cached_object[c]['y'][i] +# f = self.cached_object[c]['f'][i] * 1. +# w = self.cached_object[c]['w'][i] - # Accumulate the results into the simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a - this_object[miny:maxy + 1, minx:maxx + 1] += a +# if trans is not None: +# f *= trans(w) - time2 = time.time() - log.debug(f"Elapsed time {time2-time1} sec") +# minx = self.cached_object[c]['minx'][i] +# maxx = self.cached_object[c]['maxx'][i] +# miny = self.cached_object[c]['miny'][i] +# maxy = self.cached_object[c]['maxy'][i] + +# a = sparse.coo_matrix((f, (y - miny, x - minx)), +# shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + +# # Accumulate the results into the simulated images +# self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a +# this_object[miny:maxy + 1, minx:maxx + 1] += a + +# time2 = time.time() +# log.debug(f"Elapsed time {time2-time1} sec") - return this_object +# return this_object diff --git a/jwst/wfss_contam/tests/test_disperse.py b/jwst/wfss_contam/tests/test_disperse.py index e88c1c9adb..4ed558fade 100644 --- a/jwst/wfss_contam/tests/test_disperse.py +++ b/jwst/wfss_contam/tests/test_disperse.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from jwst.wfss_contam.disperse import interpolate_fluxes, determine_wl_spacing +from jwst.wfss_contam.disperse import flux_interpolator_injector, determine_wl_spacing ''' Note that main disperse.py call is tested in test_observations.py because @@ -13,7 +13,7 @@ ([1, 3], [1, 3], True, 4)]) def test_interpolate_fluxes(lams, flxs, extrapolate_sed, expected_outside_bounds): - flux_interpf = interpolate_fluxes(lams, flxs, extrapolate_sed) + flux_interpf = flux_interpolator_injector(lams, flxs, extrapolate_sed) assert flux_interpf(2.0) == 2.0 assert flux_interpf(4.0) == expected_outside_bounds diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index baf0e3d425..99e21d763e 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -12,7 +12,7 @@ from jwst.wfss_contam.observations import background_subtract, _select_ids, Observation from jwst.wfss_contam.disperse import dispersed_pixel from jwst.wfss_contam.tests import data -from jwst.datamodels import SegmentationMapModel, ImageModel +from jwst.datamodels import SegmentationMapModel, ImageModel, MultiSlitModel data_path = os.path.split(os.path.abspath(data.__file__))[0] DIR_IMAGE = "direct_image.fits" @@ -181,6 +181,37 @@ def test_disperse_chunk_null(observation): assert np.all(chunk == 0) +def test_disperse_all(observation): + + obs = observation + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # shorten pixel list to make this test take less time + obs.xs = obs.xs[:3] + obs.ys = obs.ys[:3] + obs.fluxes[2.0] = obs.fluxes[2.0][:3] + obs.disperse_all(order, wmin, wmax, sens_waves, sens_resp, cache=False) + + # test simulated image. should be mostly but not all zeros + assert obs.simulated_image.shape == obs.dims + assert not np.allclose(obs.simulated_image, 0.0) + assert np.median(obs.simulated_image) == 0.0 + + # test simulated slits and their associated metadata + # only the second of the two obs IDs is in the simulated image + assert obs.simul_slits_order == [order,]*1 + assert obs.simul_slits_sid == obs.IDs[-1:] + assert type(obs.simul_slits) == MultiSlitModel + + def test_disperse_oversample_same_result(grism_wcs, segmentation_map): ''' Coverage for bug where wavelength oversampling led to double-counted fluxes @@ -220,4 +251,38 @@ def test_disperse_oversample_same_result(grism_wcs, segmentation_map): oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) - assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) \ No newline at end of file + assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) + + +def test_construct_slitmodel_for_chunk(observation): + ''' + test that the chunk is constructed correctly + ''' + obs = observation + i = 1 + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # set all fluxes to unity to try to make a trivial example + obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape) + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) + + slit = obs.construct_slitmodel_for_chunk(chunk, chunk_bounds, sid, order_out) + + # check that the metadata is correct + assert slit.xstart == chunk_bounds[0] + assert slit.xsize == chunk_bounds[1] - chunk_bounds[0] + 1 + assert slit.ystart == chunk_bounds[2] + assert slit.ysize == chunk_bounds[3] - chunk_bounds[2] + 1 + assert slit.source_id == sid + assert slit.meta.wcsinfo.spectral_order == order_out + assert np.allclose(slit.data, chunk[chunk_bounds[2]:chunk_bounds[3]+1, chunk_bounds[0]:chunk_bounds[1]+1]) diff --git a/jwst/wfss_contam/tests/test_wfss_contam.py b/jwst/wfss_contam/tests/test_wfss_contam.py index 656e78c451..a791b0001b 100644 --- a/jwst/wfss_contam/tests/test_wfss_contam.py +++ b/jwst/wfss_contam/tests/test_wfss_contam.py @@ -1,5 +1,5 @@ import pytest -from jwst.wfss_contam.wfss_contam import _determine_multiprocessing_ncores, _cut_frame_to_match_slit, build_common_slit +from jwst.wfss_contam.wfss_contam import determine_multiprocessing_ncores, _cut_frame_to_match_slit, _find_matching_simul_slit, build_common_slit from jwst.datamodels import SlitModel import numpy as np @@ -9,9 +9,11 @@ ("quarter", 4, 1), ("half", 4, 2), ("all", 4, 4), - ("none", 1, 1),]) + ("none", 1, 1), + (None, 1, 1,), + (3, 5, 3)]) def test_determine_multiprocessing_ncores(max_cores, num_cores, expected): - assert _determine_multiprocessing_ncores(max_cores, num_cores) == expected + assert determine_multiprocessing_ncores(max_cores, num_cores) == expected @pytest.fixture(scope="module") @@ -25,6 +27,8 @@ def slit0(): slit.ystart = 3 slit.xsize = 3 slit.ysize = 5 + slit.meta.wcsinfo.spectral_order = 1 + slit.source_id = 1 return slit @@ -38,6 +42,20 @@ def slit1(): return slit +def test_find_matching_simul_slit(slit0): + sids = [0, 1, 1] + orders = [1, 1, 2] + idx = _find_matching_simul_slit(slit0, sids, orders) + assert idx == 1 + + +def test_find_matching_simul_slit_no_match(slit0): + sids = [0, 1, 1] + orders = [1, 2, 2] + idx = _find_matching_simul_slit(slit0, sids, orders) + assert idx == -1 + + def test_cut_frame_to_match_slit(slit0, contam): cut_contam = _cut_frame_to_match_slit(contam, slit0) assert cut_contam.shape == (5, 3) diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 6662d3d52d..83be7f7160 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,5 +1,6 @@ import logging import multiprocessing +from typing import Union import numpy as np from stdatamodels.jwst import datamodels @@ -13,14 +14,19 @@ log.setLevel(logging.DEBUG) -def _determine_multiprocessing_ncores(max_cores, num_cores): +def determine_multiprocessing_ncores(max_cores: Union[str, int], num_cores) -> int: """Determine the number of cores to use for multiprocessing. Parameters ---------- - max_cores : string - See docstring of contam_corr + max_cores : string or int + Number of cores to use for multiprocessing. If set to 'none' + (the default), then no multiprocessing will be done. The other + allowable string values are 'quarter', 'half', and 'all', which indicate + the fraction of cores to use for multi-proc. The total number of + cores includes the SMT cores (Hyper Threading for Intel). + If an integer is provided, it will be the exact number of cores used. num_cores : int Number of cores available on the machine @@ -28,25 +34,31 @@ def _determine_multiprocessing_ncores(max_cores, num_cores): ------- ncpus : int Number of cores to use for multiprocessing - """ - if max_cores == 'none': - ncpus = 1 - else: - if max_cores == 'quarter': - ncpus = num_cores // 4 or 1 - elif max_cores == 'half': - ncpus = num_cores // 2 or 1 - elif max_cores == 'all': - ncpus = num_cores - else: + match max_cores: + case 'none': + return 1 + case None: + return 1 + case 'quarter': + return num_cores // 4 or 1 + case 'half': + return num_cores // 2 or 1 + case 'all': + return num_cores + case int(): + if max_cores <= num_cores and max_cores > 0: + return max_cores + log.warning(f"Requested {max_cores} cores exceeds the number of cores available on this machine ({num_cores}). Using all available cores.") + return max_cores + case _: raise ValueError(f"Invalid value for max_cores: {max_cores}") - log.debug(f"Found {num_cores} cores; using {ncpus}") - - return ncpus -def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders): +def _find_matching_simul_slit(slit: datamodels.SlitModel, + simul_slit_sids: list[int], + simul_slit_orders: list[int], + ) -> int: """ Parameters ---------- @@ -66,13 +78,13 @@ def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders): # Retrieve simulated slit for this source only sid = slit.source_id order = slit.meta.wcsinfo.spectral_order - good = (simul_slit_sids == sid) * (simul_slit_orders == order) + good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order) if not any(good): return -1 return np.where(good)[0][0] -def _cut_frame_to_match_slit(contam, slit): +def _cut_frame_to_match_slit(contam: np.ndarray, slit: datamodels.SlitModel) -> np.ndarray: """Cut out the contamination image to match the extent of the source slit. @@ -96,7 +108,9 @@ def _cut_frame_to_match_slit(contam, slit): return cutout -def build_common_slit(slit0, slit1): +def build_common_slit(slit0: datamodels.SlitModel, + slit1: datamodels.SlitModel, + ) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: ''' put data from the two slits into a common backplane so outputs have the same dimensions @@ -148,7 +162,12 @@ def build_common_slit(slit0, slit1): return slit0, slit1 -def contam_corr(input_model, waverange, photom, max_cores="none", brightest_n=None): +def contam_corr(input_model: datamodels.MultiSlitModel, + waverange: datamodels.WavelengthrangeModel, + photom: datamodels.NrcWfssPhotomModel | datamodels.NisWfssPhotomModel, + max_cores: str | int = "none", + brightest_n: int = None, + ) -> tuple[datamodels.MultiSlitModel, datamodels.ImageModel, datamodels.MultiSlitModel, datamodels.MultiSlitModel]: """ The main WFSS contamination correction function @@ -160,12 +179,13 @@ def contam_corr(input_model, waverange, photom, max_cores="none", brightest_n=No Wavelength range reference file model photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel` Photom (flux cal) reference file model - max_cores : string + max_cores : string or int Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other - allowable values are 'quarter', 'half', and 'all', which indicate + allowable string values are 'quarter', 'half', and 'all', which indicate the fraction of cores to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). + If an integer is provided, it will be the exact number of cores used. brightest_n : int Number of sources to simulate. If None, then all sources in the input model will be simulated. Requires loading the source catalog @@ -185,7 +205,7 @@ def contam_corr(input_model, waverange, photom, max_cores="none", brightest_n=No """ num_cores = multiprocessing.cpu_count() - ncpus = _determine_multiprocessing_ncores(max_cores, num_cores) + ncpus = determine_multiprocessing_ncores(max_cores, num_cores) # Initialize output model output_model = input_model.copy() diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index e77c662c5a..b9ebfddc98 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -25,7 +25,9 @@ class WfssContamStep(Step): reference_file_types = ['photom', 'wavelengthrange'] - def process(self, input_model, *args, **kwargs): + def process(self, + input_model: str | datamodels.MultiSlitModel, + ) -> datamodels.MultiSlitModel: with datamodels.open(input_model) as dm: From 140f9ef30f0c30fd32d6fd4a773dd7c407d8cc27 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 22 Apr 2024 13:15:26 -0400 Subject: [PATCH 10/11] fixed multiprocessing error, fixed output metadata, fixed output slitmodel shape --- jwst/assign_wcs/niriss.py | 10 +- jwst/wfss_contam/disperse.py | 28 +++- jwst/wfss_contam/observations.py | 14 +- jwst/wfss_contam/tests/test_wfss_contam.py | 64 +++++-- jwst/wfss_contam/wfss_contam.py | 183 ++++++++++++++------- jwst/wfss_contam/wfss_contam_step.py | 9 +- 6 files changed, 217 insertions(+), 91 deletions(-) diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index 2da04a5751..9d7aab559a 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -398,11 +398,11 @@ def wfss(input_model, reference_files): # Get the disperser parameters which are defined as a model for each # spectral order with NIRISSGrismModel(reference_files['specwcs']) as f: - dispx = f.dispx - dispy = f.dispy - displ = f.displ - invdispl = f.invdispl - orders = f.orders + dispx = f.dispx.instance + dispy = f.dispy.instance + displ = f.displ.instance + invdispl = f.invdispl.instance + orders = f.orders.instance fwcpos_ref = f.fwcpos_ref # This is the actual rotation from the input model diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index d8150ef864..1145dc8423 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -1,3 +1,4 @@ +from functools import partial import numpy as np from typing import Callable, Sequence from astropy.wcs import WCS @@ -9,6 +10,25 @@ from .sens1d import create_1d_sens +def flat_lam(fluxes: np.ndarray, lams: np.ndarray) -> np.ndarray: + ''' + Parameters + ---------- + x : float + x-coordinate of the pixel. + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + + Returns + ------- + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + ''' + return fluxes[0] + + def flux_interpolator_injector(lams: np.ndarray, flxs: np.ndarray, extrapolate_sed: bool, @@ -34,7 +54,7 @@ def flux_interpolator_injector(lams: np.ndarray, ''' if len(lams) > 1: - # If we have direct image flux values from more than one filter (lambda), + # If we have direct image flux values from more than one filter (lams), # we have the option to extrapolate the fluxes outside the # wavelength range of the direct images if extrapolate_sed is False: @@ -42,11 +62,9 @@ def flux_interpolator_injector(lams: np.ndarray, else: return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) else: - # If we only have flux from one lambda, just use that + # If we only have flux from one wavelength, just use that # single flux value at all wavelengths - def flux(x): - return flxs[0] - return flux + return partial(flat_lam, flxs) def determine_wl_spacing(dw: float, diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 4923416b98..0f02a9d21d 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,3 +1,4 @@ +import copy import time import numpy as np import multiprocessing as mp @@ -18,6 +19,15 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +def disperse_multiprocess(pars, max_cpu): + + pars = copy.deepcopy(pars) + ctx = mp.get_context("forkserver") + with ctx.Pool(max_cpu) as mypool: + all_res = mypool.starmap(dispersed_pixel, pars) + + return all_res + def background_subtract(data: np.ndarray, box_size: tuple = None, @@ -386,9 +396,7 @@ def disperse_chunk(self, # pass parameters into dispersed_pixel, either using multiprocessing or not time1 = time.time() if self.max_cpu > 1: - ctx = mp.get_context("forkserver") - with ctx.Pool(self.max_cpu) as mypool: - all_res = mypool.starmap(dispersed_pixel, pars) + all_res = disperse_multiprocess(pars, self.max_cpu) else: all_res = [] for i in range(len(pars)): diff --git a/jwst/wfss_contam/tests/test_wfss_contam.py b/jwst/wfss_contam/tests/test_wfss_contam.py index a791b0001b..b1418b77d1 100644 --- a/jwst/wfss_contam/tests/test_wfss_contam.py +++ b/jwst/wfss_contam/tests/test_wfss_contam.py @@ -1,5 +1,5 @@ import pytest -from jwst.wfss_contam.wfss_contam import determine_multiprocessing_ncores, _cut_frame_to_match_slit, _find_matching_simul_slit, build_common_slit +from jwst.wfss_contam.wfss_contam import CommonSlitEncompass, CommonSlitPreferFirst, SlitOverlapError, UnmatchedSlitIDError, determine_multiprocessing_ncores, _cut_frame_to_match_slit, _find_matching_simul_slit from jwst.datamodels import SlitModel import numpy as np @@ -42,6 +42,16 @@ def slit1(): return slit +@pytest.fixture(scope="module") +def slit2(): + slit = SlitModel(data=np.ones((3, 5))*0.1) + slit.xstart = 300 + slit.ystart = 200 + slit.xsize = 5 + slit.ysize = 3 + return slit + + def test_find_matching_simul_slit(slit0): sids = [0, 1, 1] orders = [1, 1, 2] @@ -52,8 +62,8 @@ def test_find_matching_simul_slit(slit0): def test_find_matching_simul_slit_no_match(slit0): sids = [0, 1, 1] orders = [1, 2, 2] - idx = _find_matching_simul_slit(slit0, sids, orders) - assert idx == -1 + with pytest.raises(UnmatchedSlitIDError): + _find_matching_simul_slit(slit0, sids, orders) def test_cut_frame_to_match_slit(slit0, contam): @@ -62,21 +72,45 @@ def test_cut_frame_to_match_slit(slit0, contam): assert np.all(cut_contam == 0.1) -def test_build_common_slit(slit0, slit1): - slit0, slit1 = build_common_slit(slit0, slit1) +def test_common_slit_encompass(slit0, slit1): + slit0_final, slit1_final = CommonSlitEncompass(slit0.copy(), slit1.copy()).match_backplane() # check indexing in metadata - assert slit0.xstart == slit1.xstart - assert slit0.ystart == slit1.ystart - assert slit0.xsize == slit1.xsize - assert slit0.ysize == slit1.ysize - assert slit0.data.shape == slit1.data.shape + assert slit0_final.xstart == slit1_final.xstart + assert slit0_final.ystart == slit1_final.ystart + assert slit0_final.xsize == slit1_final.xsize + assert slit0_final.ysize == slit1_final.ysize + assert slit0_final.data.shape == slit1_final.data.shape # check data overlap - assert np.count_nonzero(slit0.data) == 15 - assert np.count_nonzero(slit1.data) == 16 - assert np.count_nonzero(slit0.data * slit1.data) == 6 + assert np.count_nonzero(slit0_final.data) == 15 + assert np.count_nonzero(slit1_final.data) == 16 + assert np.count_nonzero(slit0_final.data * slit1_final.data) == 6 # check data values - assert np.all(slit0.data[1:6, 0:3] == 1) - assert np.all(slit1.data[0:4, 1:5] == 0.5) + assert np.all(slit0_final.data[1:6, 0:3] == 1) + assert np.all(slit1_final.data[0:4, 1:5] == 0.5) + + +def test_common_slit_prefer(slit0, slit1): + + slit0_final, slit1_final = CommonSlitPreferFirst(slit0.copy(), slit1.copy()).match_backplane() + assert slit0_final.xstart == slit0.xstart + assert slit0_final.ystart == slit0.ystart + assert slit0_final.xsize == slit0.xsize + assert slit0_final.ysize == slit0.ysize + assert slit0_final.data.shape == slit0.data.shape + assert np.all(slit0_final.data == slit0.data) + + assert slit1_final.xstart == slit0.xstart + assert slit1_final.ystart == slit0.ystart + assert slit1_final.xsize == slit0.xsize + assert slit1_final.ysize == slit0.ysize + assert slit1_final.data.shape == slit0.data.shape + assert np.count_nonzero(slit1_final.data) == 6 + + +def test_common_slit_prefer_expected_raise(slit0, slit2): + + with pytest.raises(SlitOverlapError): + CommonSlitPreferFirst(slit0.copy(), slit2.copy()).match_backplane() \ No newline at end of file diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 83be7f7160..4ba5275458 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,6 +1,6 @@ import logging import multiprocessing -from typing import Union +from typing import Protocol, Union import numpy as np from stdatamodels.jwst import datamodels @@ -55,6 +55,10 @@ def determine_multiprocessing_ncores(max_cores: Union[str, int], num_cores) -> i raise ValueError(f"Invalid value for max_cores: {max_cores}") +class UnmatchedSlitIDError(Exception): + pass + + def _find_matching_simul_slit(slit: datamodels.SlitModel, simul_slit_sids: list[int], simul_slit_orders: list[int], @@ -80,7 +84,8 @@ def _find_matching_simul_slit(slit: datamodels.SlitModel, order = slit.meta.wcsinfo.spectral_order good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order) if not any(good): - return -1 + raise UnmatchedSlitIDError(f"Source ID {sid} order {order} requested by input slit model \ + but not found in simulated slits. Setting contamination correction to zero for that slit.") return np.where(good)[0][0] @@ -108,58 +113,116 @@ def _cut_frame_to_match_slit(contam: np.ndarray, slit: datamodels.SlitModel) -> return cutout -def build_common_slit(slit0: datamodels.SlitModel, - slit1: datamodels.SlitModel, - ) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: +class SlitOverlapError(Exception): + pass + +class CommonSlit(Protocol): ''' - put data from the two slits into a common backplane - so outputs have the same dimensions - and alignment is based on slit.xstart, slit.ystart + class protocol for two slits that represent the same source and order, e.g. data and model + ''' + slit0: datamodels.SlitModel + slit1: datamodels.SlitModel + + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + ... - Parameters - ---------- - slit0 : SlitModel - First slit model - slit1 : SlitModel - Second slit model - Returns - ------- - slit0 : SlitModel - First slit model with data updated to common backplane - slit1 : SlitModel - Second slit model with data updated to common backplane +class CommonSlitPreferFirst(CommonSlit): ''' + Treat slit0 as the reference slit, and match attributes of slit1 to it + ''' + def __init__(self, slit0: datamodels.SlitModel, slit1: datamodels.SlitModel): + self.slit0 = slit0 + self.slit1 = slit1 + + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + + data0 = self.slit0.data + data1 = self.slit1.data + + x1 = self.slit1.xstart - self.slit0.xstart + y1 = self.slit1.ystart - self.slit0.ystart + backplane1 = np.zeros_like(data0) + + i0 = max([y1,0]) + i1 = min([y1+data1.shape[0], data0.shape[0], data1.shape[0]]) + j0 = max([x1,0]) + j1 = min([x1+data1.shape[1], data0.shape[1], data1.shape[1]]) + if i0 >= i1 or j0 >= j1: + raise SlitOverlapError(f"No overlap region between data and model for slit {self.slit0.sid}, \ + order {self.slit0.meta.spectral_order}. \ + Setting contamination correction to zero for that slit.") + + breakpoint() + backplane1[i0:i1, j0:j1] = data1[i0:i1, j0:j1] + + self.slit1.data = backplane1 + self.slit1.xstart = self.slit0.xstart + self.slit1.ystart = self.slit0.ystart + self.slit1.xsize = self.slit0.xsize + self.slit1.ysize = self.slit0.ysize - data0 = slit0.data - data1 = slit1.data - - shape = (max(data0.shape[0], data1.shape[0]), max(data0.shape[1], data1.shape[1])) - xmin = min(slit0.xstart, slit1.xstart) - ymin = min(slit0.ystart, slit1.ystart) - shape = max(slit0.xsize + slit0.xstart - xmin, - slit1.xsize + slit1.xstart - xmin), \ - max(slit0.ysize + slit0.ystart - ymin, - slit1.ysize + slit1.ystart - ymin) - x0 = slit0.xstart - xmin - y0 = slit0.ystart - ymin - x1 = slit1.xstart - xmin - y1 = slit1.ystart - ymin - - backplane0 = np.zeros(shape).T - backplane0[y0:y0+data0.shape[0], x0:x0+data0.shape[1]] = data0 - backplane1 = np.zeros(shape).T - backplane1[y1:y1+data1.shape[0], x1:x1+data1.shape[1]] = data1 - - slit0.data = backplane0 - slit1.data = backplane1 - for slit in [slit0, slit1]: - slit.xstart = xmin - slit.ystart = ymin - slit.xsize = shape[0] - slit.ysize = shape[1] + return self.slit0, self.slit1 + + +class CommonSlitEncompass(CommonSlit): + ''' + Encompass the data from both slits in a common backplane + ''' + def __init__(self, slit0: datamodels.SlitModel, slit1: datamodels.SlitModel): + self.slit0 = slit0 + self.slit1 = slit1 - return slit0, slit1 + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + ''' + put data from the two slits into a common backplane + so outputs have the same dimensions + and alignment is based on slit.xstart, slit.ystart + + Parameters + ---------- + slit0 : SlitModel + First slit model + slit1 : SlitModel + Second slit model + + Returns + ------- + slit0 : SlitModel + First slit model with data updated to common backplane + slit1 : SlitModel + Second slit model with data updated to common backplane + ''' + + data0 = self.slit0.data + data1 = self.slit1.data + + shape = (max(data0.shape[0], data1.shape[0]), max(data0.shape[1], data1.shape[1])) + xmin = min(self.slit0.xstart, self.slit1.xstart) + ymin = min(self.slit0.ystart, self.slit1.ystart) + shape = max(self.slit0.xsize + self.slit0.xstart - xmin, + self.slit1.xsize + self.slit1.xstart - xmin), \ + max(self.slit0.ysize + self.slit0.ystart - ymin, + self.slit1.ysize + self.slit1.ystart - ymin) + x0 = self.slit0.xstart - xmin + y0 = self.slit0.ystart - ymin + x1 = self.slit1.xstart - xmin + y1 = self.slit1.ystart - ymin + + backplane0 = np.zeros(shape).T + backplane0[y0:y0+data0.shape[0], x0:x0+data0.shape[1]] = data0 + backplane1 = np.zeros(shape).T + backplane1[y1:y1+data1.shape[0], x1:x1+data1.shape[1]] = data1 + + self.slit0.data = backplane0 + self.slit1.data = backplane1 + for slit in [self.slit0, self.slit1]: + slit.xstart = xmin + slit.ystart = ymin + slit.xsize = shape[0] + slit.ysize = shape[1] + + return self.slit0, self.slit1 def contam_corr(input_model: datamodels.MultiSlitModel, @@ -259,6 +322,7 @@ def contam_corr(input_model: datamodels.MultiSlitModel, good_slits = [slit for slit in output_model.slits if slit.source_id in obs.IDs] output_model = datamodels.MultiSlitModel() + output_model.update(input_model) output_model.slits.extend(good_slits) log.info(f"Simulating only the brightest {brightest_n} sources") @@ -295,18 +359,17 @@ def contam_corr(input_model: datamodels.MultiSlitModel, slits = [] for slit in output_model.slits: - good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders) - if good_idx == -1: - log.warning(f"Source {slit.source_id} order {order} requested by input slit model \ - but not found in simulated slits") - continue - this_simul = obs.simul_slits.slits[good_idx] - - # Subtract source slit to make contamination image - # Simulated slits are sometimes different in shape than input data slits by a few pixels - this_simul, slit = build_common_slit(this_simul, slit) - simul_all_cut = _cut_frame_to_match_slit(simul_all, slit) - contam_cut = simul_all_cut - this_simul.data + try: + good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders) + this_simul = obs.simul_slits.slits[good_idx] + slit, this_simul = CommonSlitPreferFirst(slit, this_simul).match_backplane() + simul_all_cut = _cut_frame_to_match_slit(simul_all, slit) + contam_cut = simul_all_cut - this_simul.data + + except (UnmatchedSlitIDError, SlitOverlapError) as e: + log.warning(e) + contam_cut = np.zeros_like(slit.data) + contam_slit = copy.copy(slit) contam_slit.data = contam_cut slits.append(contam_slit) diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index b9ebfddc98..d73afd3671 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -1,4 +1,5 @@ #! /usr/bin/env python +import copy from stdatamodels.jwst import datamodels from ..stpipe import Step @@ -31,8 +32,6 @@ def process(self, with datamodels.open(input_model) as dm: - max_cores = self.maximum_cores - # Get the wavelengthrange ref file waverange_ref = self.get_reference_file(dm, 'wavelengthrange') self.log.info(f'Using WAVELENGTHRANGE reference file {waverange_ref}') @@ -46,7 +45,7 @@ def process(self, result, simul, contam, simul_slits = wfss_contam.contam_corr(dm, waverange_model, photom_model, - max_cores, + self.maximum_cores, brightest_n=self.brightest_n) # Save intermediate results, if requested @@ -60,4 +59,8 @@ def process(self, self.log.info(f'Contamination estimates saved to "{contam_path}"') # Return the corrected data + + print(input_model.info()) + print(result.info()) + return result From 632fdec928c328fdc982c9967b1e6064c43301b1 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 22 Apr 2024 13:47:53 -0400 Subject: [PATCH 11/11] fixed style check, removed breakpoints and prints, small fix to log message --- jwst/wfss_contam/wfss_contam.py | 5 ++--- jwst/wfss_contam/wfss_contam_step.py | 6 ------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 4ba5275458..77b3446475 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -149,11 +149,10 @@ def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: j0 = max([x1,0]) j1 = min([x1+data1.shape[1], data0.shape[1], data1.shape[1]]) if i0 >= i1 or j0 >= j1: - raise SlitOverlapError(f"No overlap region between data and model for slit {self.slit0.sid}, \ - order {self.slit0.meta.spectral_order}. \ + raise SlitOverlapError(f"No overlap region between data and model for slit {self.slit0.source_id}, \ + order {self.slit0.meta.wcsinfo.spectral_order}. \ Setting contamination correction to zero for that slit.") - breakpoint() backplane1[i0:i1, j0:j1] = data1[i0:i1, j0:j1] self.slit1.data = backplane1 diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index d73afd3671..52944538d0 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -1,5 +1,4 @@ #! /usr/bin/env python -import copy from stdatamodels.jwst import datamodels from ..stpipe import Step @@ -58,9 +57,4 @@ def process(self, contam_path = self.save_model(contam, suffix="contam", force=True) self.log.info(f'Contamination estimates saved to "{contam_path}"') - # Return the corrected data - - print(input_model.info()) - print(result.info()) - return result