diff --git a/changes/9022.extract_1d.rst b/changes/9022.extract_1d.rst new file mode 100644 index 0000000000..2446dabeba --- /dev/null +++ b/changes/9022.extract_1d.rst @@ -0,0 +1,2 @@ +Expanded the ``use_source_posn`` option to calculate a source trace from WCS and expected source positions for unresampled NIRSpec and MIRI LRS fixed slit data. +Added the step parameter ``position_offset`` to allow an additional aperture offset in pixels. diff --git a/docs/jwst/extract_1d/arguments.rst b/docs/jwst/extract_1d/arguments.rst index 7311608d9a..5d3b47e7ae 100644 --- a/docs/jwst/extract_1d/arguments.rst +++ b/docs/jwst/extract_1d/arguments.rst @@ -26,9 +26,13 @@ Step Arguments for Slit and Slitless Spectroscopic Data file should be shifted to account for the expected position of the source. If None (the default), the step will decide whether to use the source position based on the observing mode and the source type. By default, source position corrections - are attempted only for NIRSpec MOS and NIRSpec and MIRI LRS fixed-slit point sources. + are attempted only for point sources in NIRSpec MOS/FS/BOTS and MIRI LRS fixed-slit exposures. Set to False to ignore position estimates for all modes; set to True to additionally attempt - source position correction for NIRSpec BOTS data or extended sources. + source position correction for extended sources. + +``--position_offset`` + Specify a number of pixels (fractional pixels are allowed) to offset the + extraction aperture from the nominal position. The default is 0. ``--smoothing_length`` If ``smoothing_length`` is greater than 1 (and is an odd integer), the diff --git a/docs/jwst/extract_1d/description.rst b/docs/jwst/extract_1d/description.rst index bd31735236..c6f143bebc 100644 --- a/docs/jwst/extract_1d/description.rst +++ b/docs/jwst/extract_1d/description.rst @@ -160,17 +160,26 @@ If `extract_width` is also given, the start and stop values are used to define the center of the extraction region in the cross-dispersion direction, but the width of the aperture is set by the `extract_width` value. -For some instruments and modes, the cross-dispersion start and stop values may be shifted -to account for the expected location of the source. This option -is available for NIRSpec MOS, fixed-slit, and BOTS data, as well as MIRI LRS fixed-slit. +For some instruments and modes, the extraction region may be adjusted +to account for the expected location of the source with the `use_source_posn` +option. This option is available for NIRSpec MOS, fixed-slit, and BOTS data, +as well as MIRI LRS fixed-slit. If `use_source_posn` is set to None via the reference file or input parameters, -it is turned on by default for all point sources in these modes, except NIRSpec BOTS. -To turn it on for NIRSpec BOTS or extended sources, set `use_source_posn` to True. +it is turned on by default for all point sources in these modes. +To turn it on for extended sources, set `use_source_posn` to True. To turn it off for any mode, set `use_source_posn` to False. -If source position correction is enabled, the planned location for the source is -calculated internally, via header metadata recording the source position and the -spectral WCS transforms, then used to offset the extraction start and stop values -in the cross-dispersion direction. +If source position option is enabled, the planned location for the source and its +trace are calculated internally via header metadata recording the source position +and the spectral WCS transforms. The source location will be used to offset the +extraction start and stop values in the cross-dispersion direction. +If `extract_width` is provided, the source extraction region will be centered +on the calculated trace with a width set by the `extract_width` value. +For resampled, "s2d", products this will effectively be the rectangular +extraction region offset in the cross-dispersion direction. For +"cal" or "calints" products that have not been resampled, the extraction region +will be curved to follow the calculated trace. +If no `extract_width` has been provided, the shifted extraction start and +stop values will be used. A more flexible way to specify the source extraction region is via the `src_coeff` parameter. `src_coeff` is specified as a list of lists of floating-point diff --git a/jwst/extract_1d/extract.py b/jwst/extract_1d/extract.py index 13d24a6811..f78b826e0f 100644 --- a/jwst/extract_1d/extract.py +++ b/jwst/extract_1d/extract.py @@ -4,6 +4,8 @@ from json.decoder import JSONDecodeError from astropy.modeling import polynomial +from gwcs.wcstools import grid_from_bounding_box +from scipy.interpolate import interp1d from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels.apcorr import ( MirLrsApcorrModel, MirMrsApcorrModel, NrcWfssApcorrModel, NrsFsApcorrModel, @@ -19,8 +21,8 @@ __all__ = ['run_extract1d', 'read_extract1d_ref', 'read_apcorr_ref', 'get_extract_parameters', 'box_profile', 'aperture_center', - 'location_from_wcs', 'shift_by_source_location', 'define_aperture', - 'extract_one_slit', 'create_extraction'] + 'location_from_wcs', 'shift_by_offset', + 'define_aperture', 'extract_one_slit', 'create_extraction'] log = logging.getLogger(__name__) @@ -29,6 +31,9 @@ WFSS_EXPTYPES = ['NIS_WFSS', 'NRC_WFSS', 'NRC_GRISM'] """Exposure types to be regarded as wide-field slitless spectroscopy.""" +SRCPOS_EXPTYPES = ['MIR_LRS-FIXEDSLIT', 'NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ'] +"""Exposure types for which source position can be estimated.""" + ANY = "ANY" """Wildcard for slit name. @@ -140,7 +145,8 @@ def read_apcorr_ref(refname, exptype): def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, smoothing_length=None, bkg_fit=None, bkg_order=None, - use_source_posn=None, subtract_background=None): + use_source_posn=None, subtract_background=None, + position_offset=0.0): """Get extraction parameter values. Parameters @@ -203,6 +209,11 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, subtract_background : bool or None, optional If False, all background parameters will be ignored. + position_offset : float or None, optional + Pixel offset to apply to the nominal source location. + If None, the value specified in `ref_dict` will be used or it + will default to 0. + Returns ------- extract_params : dict @@ -232,6 +243,8 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['use_source_posn'] = False # no source position correction extract_params['position_correction'] = 0 extract_params['independent_var'] = 'pixel' + extract_params['position_offset'] = 0. + extract_params['trace'] = None # Note that extract_params['dispaxis'] is not assigned. # This will be done later, possibly slit by slit. @@ -305,7 +318,7 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, if use_source_posn is None: # no value set on command line if use_source_posn_aper is None: # no value set in ref file # Use a suitable default - if meta.exposure.type in ['MIR_LRS-FIXEDSLIT', 'NRS_FIXEDSLIT', 'NRS_MSASPEC']: + if meta.exposure.type in SRCPOS_EXPTYPES: use_source_posn = True log.info(f"Turning on source position correction " f"for exp_type = {meta.exposure.type}") @@ -314,6 +327,8 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, else: # use the value from the ref file use_source_posn = use_source_posn_aper extract_params['use_source_posn'] = use_source_posn + extract_params['position_offset'] = position_offset + extract_params['trace'] = None extract_params['extract_width'] = aper.get('extract_width') extract_params['position_correction'] = 0 # default value @@ -363,21 +378,10 @@ def log_initial_parameters(extract_params): return log.debug("Extraction parameters:") - log.debug(f"dispaxis = {extract_params['dispaxis']}") - log.debug(f"spectral order = {extract_params['spectral_order']}") - log.debug(f"initial xstart = {extract_params['xstart']}") - log.debug(f"initial xstop = {extract_params['xstop']}") - log.debug(f"initial ystart = {extract_params['ystart']}") - log.debug(f"initial ystop = {extract_params['ystop']}") - log.debug(f"extract_width = {extract_params['extract_width']}") - log.debug(f"initial src_coeff = {extract_params['src_coeff']}") - log.debug(f"initial bkg_coeff = {extract_params['bkg_coeff']}") - log.debug(f"bkg_fit = {extract_params['bkg_fit']}") - log.debug(f"bkg_order = {extract_params['bkg_order']}") - log.debug(f"smoothing_length = {extract_params['smoothing_length']}") - log.debug(f"independent_var = {extract_params['independent_var']}") - log.debug(f"use_source_posn = {extract_params['use_source_posn']}") - log.debug(f"extraction_type = {extract_params['extraction_type']}") + skip_keys = {'match', 'trace'} + for key, value in extract_params.items(): + if key not in skip_keys: + log.debug(f" {key} = {value}") def create_poly(coeff): @@ -724,9 +728,10 @@ def box_profile(shape, extract_params, wl_array, coefficients='src_coeff', `extract_params`, in this priority order: 1. src_coeff upper and lower limits (or bkg_coeff, for a background profile) - 2. center of start/stop values +/- extraction width / 2 - 3. cross-dispersion start/stop values - 4. array limits. + 2. trace +/- extraction width / 2 + 3. center of start/stop values +/- extraction width / 2 + 4. cross-dispersion start/stop values + 5. array limits. Left and right limits are set from start/stop values only. @@ -792,8 +797,9 @@ def box_profile(shape, extract_params, wl_array, coefficients='src_coeff', # Set aperture region, in this priority order: # 1. src_coeff upper and lower limits (or bkg_coeff, for background profile) - # 2. center of start/stop values +/- extraction width - # 3. start/stop values + # 2. trace +/- extraction width / 2 + # 3. center of start/stop values +/- extraction width / 2 + # 4. start/stop values profile = np.full(shape, 0.0) if extract_params[coefficients] is not None: # Limits from source coefficients: ignore ystart/stop/width @@ -848,6 +854,24 @@ def box_profile(shape, extract_params, wl_array, coefficients='src_coeff', if mean_upper > upper_limit: upper_limit = mean_upper + elif extract_params['extract_width'] is not None and extract_params['trace'] is not None: + width = extract_params['extract_width'] + trace = extract_params['trace'] + + if extract_params['dispaxis'] != HORIZONTAL: + trace = np.tile(trace, (shape[1], 1)).T + + lower_limit_region = trace - (width - 1.0) / 2.0 + upper_limit_region = lower_limit_region + width - 1 + + _set_weight_from_limits(profile, dval, lower_limit_region, + upper_limit_region) + + lower_limit = np.nanmean(lower_limit_region) + upper_limit = np.nanmean(upper_limit_region) + log.info(f'Mean {label} start/stop from trace: ' + f'{lower_limit:.2f} -> {upper_limit:.2f} (inclusive)') + elif extract_params['extract_width'] is not None: # Limits from extraction width at center of ystart/stop if present, # center of array if not @@ -969,7 +993,6 @@ def location_from_wcs(input_model, slit): ---------- input_model : DataModel The input science model containing metadata information. - slit : DataModel or None One slit from a MultiSlitModel (or similar), or None. The WCS and target coordinates will be retrieved from `slit` @@ -985,29 +1008,29 @@ def location_from_wcs(input_model, slit): nominal extraction location, in case it varies along the spectrum. The offset will then be the difference between `location` (below) and the nominal location. - middle_wl : float or None The wavelength at pixel `middle`. - location : float or None Pixel coordinate in the cross-dispersion direction within the spectral image that is at the planned target location. The spectral extraction region should be centered here. + trace : ndarray or None + An array of source positions, one per dispersion element, corresponding + to the location at each point in the wavelength array. If the + input data is resampled, the trace corresponds directly to the + location. """ if slit is not None: - wcs_source = slit + shape = slit.data.shape[-2:] + wcs = slit.meta.wcs + dispaxis = slit.meta.wcsinfo.dispersion_direction else: - wcs_source = input_model - wcs = wcs_source.meta.wcs - dispaxis = wcs_source.meta.wcsinfo.dispersion_direction + shape = input_model.data.shape[-2:] + wcs = input_model.meta.wcs + dispaxis = input_model.meta.wcsinfo.dispersion_direction bb = wcs.bounding_box # ((x0, x1), (y0, y1)) if bb is None: - if slit is None: - shape = input_model.data.shape - else: - shape = slit.data.shape - bb = wcs_bbox_from_shape(shape) if dispaxis == HORIZONTAL: @@ -1033,13 +1056,15 @@ def location_from_wcs(input_model, slit): lower = bb[0][0] upper = bb[0][1] - # We need transform[2], a 1-D array of wavelengths crossing the spectrum - # near its middle. + # Get the wavelengths for the valid data in the sky transform, + # average to get the middle wavelength fwd_transform = wcs(x, y) middle_wl = np.nanmean(fwd_transform[2]) exp_type = input_model.meta.exposure.type + trace = None if exp_type in ['NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ']: + log.info("Using source_xpos and source_ypos to center extraction.") if slit is None: xpos = input_model.source_xpos ypos = input_model.source_ypos @@ -1050,12 +1075,15 @@ def location_from_wcs(input_model, slit): slit2det = wcs.get_transform('slit_frame', 'detector') if 'gwa' in wcs.available_frames: # Input is not resampled, wavelengths need to be meters - x_y = slit2det(xpos, ypos, middle_wl * 1e-6) + _, location = slit2det(xpos, ypos, middle_wl * 1e-6) else: - x_y = slit2det(xpos, ypos, middle_wl) - log.info("Using source_xpos and source_ypos to center extraction.") + _, location = slit2det(xpos, ypos, middle_wl) + + if ~np.isnan(location): + trace = _nirspec_trace_from_wcs(shape, bb, wcs, xpos, ypos) elif exp_type == 'MIR_LRS-FIXEDSLIT': + log.info("Using dithered_ra and dithered_dec to center extraction.") try: if slit is None: dithra = input_model.meta.dither.dithered_ra @@ -1063,23 +1091,21 @@ def location_from_wcs(input_model, slit): else: dithra = slit.meta.dither.dithered_ra dithdec = slit.meta.dither.dithered_dec - x_y = wcs.backward_transform(dithra, dithdec, middle_wl) + location, _ = wcs.backward_transform(dithra, dithdec, middle_wl) + except (AttributeError, TypeError): log.warning("Dithered pointing location not found in wcsinfo.") - return None, None, None - else: - log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") - return None, None, None + return None, None, None, None - # location is the XD location of the spectrum: - if dispaxis == HORIZONTAL: - location = x_y[1] + if ~np.isnan(location): + trace = _miri_trace_from_wcs(shape, bb, wcs, dithra, dithdec) else: - location = x_y[0] + log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") + return None, None, None, None if np.isnan(location): log.warning('Source position could not be determined from WCS.') - return None, None, None + return None, None, None, None # If the target is at the edge of the image or at the edge of the # non-NaN area, we can't use the WCS to find the @@ -1087,48 +1113,38 @@ def location_from_wcs(input_model, slit): if location < lower or location > upper: log.warning(f"WCS implies the target is at {location:.2f}, which is outside the bounding box,") log.warning("so we can't get spectrum location using the WCS") - return None, None, None + return None, None, None, None - return middle, middle_wl, location + return middle, middle_wl, location, trace -def shift_by_source_location(location, nominal_location, extract_params): - """Shift the nominal extraction parameters by the source location. - - The offset applied is `location` - `nominal_location`, along - the cross-dispersion direction. +def shift_by_offset(offset, extract_params, update_trace=True): + """Shift the nominal extraction parameters by a pixel offset. Start, stop, and polynomial coefficient values for source and background are updated in place in the `extract_params` dictionary. + The source trace value, if present, is also updated if desired. Parameters ---------- - location : float - The source location in the cross-dispersion direction - at which to center the extraction aperture. - nominal_location : float - The center of the nominal extraction aperture, in the - cross-dispersion direction, according to the extraction - parameters. + offset : float + Cross-dispersion offset to apply, in pixels. extract_params : dict Extraction parameters to update, as created by - `get_extraction_parameters`, and corresponding to the - specified nominal location. + `get_extraction_parameters`. + update_trace : bool + If True, the trace in `extract_params['trace']` is also updated + if present. """ - - # Get the center of the nominal aperture - offset = location - nominal_location - log.info(f"Nominal location is {nominal_location:.2f}, " - f"so offset is {offset:.2f} pixels") - - # Shift aperture limits by the difference between the - # source location and the nominal center + # Shift polynomial coefficients coeff_params = ['src_coeff', 'bkg_coeff'] for params in coeff_params: if extract_params[params] is not None: for coeff_list in extract_params[params]: coeff_list[0] += offset + + # Shift start/stop values if extract_params['dispaxis'] == HORIZONTAL: start_stop_params = ['ystart', 'ystop'] else: @@ -1137,6 +1153,134 @@ def shift_by_source_location(location, nominal_location, extract_params): if extract_params[params] is not None: extract_params[params] += offset + # Shift the full trace + if update_trace and extract_params['trace'] is not None: + extract_params['trace'] += offset + + +def _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos): + """Calculate NIRSpec source trace from WCS. + + The source trace is calculated by projecting the recorded source + positions source_xpos/ypos from the NIRSpec "slit_frame" onto + detector pixels. + + Parameters + ---------- + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing slit and detector + transforms. + source_xpos : float + Slit position, in the x direction, for the target. + source_ypos : float + Slit position, in the y direction, for the target. + + Returns + ------- + trace : ndarray of float + Fractional pixel positions in the y (cross-dispersion direction) + of the trace for each x (dispersion direction) pixel. + """ + x, y = grid_from_bounding_box(bounding_box) + nx = int(bounding_box[0][1] - bounding_box[0][0]) + + # Calculate the wavelengths in the slit frame because they are in + # meters for cal files and um for s2d files + d2s = wcs_ref.get_transform("detector", "slit_frame") + _, _, slit_wavelength = d2s(x,y) + + # Make an initial array of wavelengths that will cover the wavelength range of the data + wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), nx) + # Get arrays of the source position in the slit + pos_x = np.full(nx, source_xpos) + pos_y = np.full(nx, source_ypos) + + # Grab the wcs transform between the slit frame where we know the + # source position and the detector frame + s2d = wcs_ref.get_transform("slit_frame", "detector") + + # Calculate the expected center of the source trace + trace_x, trace_y = s2d(pos_x, pos_y, wave_vals) + + # Interpolate the trace to a regular pixel grid in the dispersion + # direction + interp_trace = interp1d(trace_x, trace_y, fill_value='extrapolate') + + # Get the trace position for each dispersion element + trace = interp_trace(np.arange(nx)) + + # Place the trace in the full array + full_trace = np.full(shape[1], np.nan) + x0 = int(np.ceil(bounding_box[0][0])) + full_trace[x0:x0 + nx] = trace + + return full_trace + + +def _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec): + """Calculate MIRI LRS fixed slit source trace from WCS. + + The source trace is calculated by projecting the recorded source + positions dithered_ra/dec from the world frame onto detector pixels. + + Parameters + ---------- + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing sky and detector + transforms, forward and backward. + source_ra : float + RA coordinate for the target. + source_dec : float + Dec coordinate for the target. + + Returns + ------- + trace : ndarray of float + Fractional pixel positions in the x (cross-dispersion direction) + of the trace for each y (dispersion direction) pixel. + """ + x, y = grid_from_bounding_box(bounding_box) + ny = int(bounding_box[1][1] - bounding_box[1][0]) + + # Calculate the wavelengths for the full array + _, _, slit_wavelength = wcs_ref(x, y) + + # Make an initial array of wavelengths that will cover the wavelength range of the data + wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), ny) + + # Get arrays of the source position + pos_ra = np.full(ny, source_ra) + pos_dec = np.full(ny, source_dec) + + # Calculate the expected center of the source trace + trace_x, trace_y = wcs_ref.backward_transform(pos_ra, pos_dec, wave_vals) + + # Interpolate the trace to a regular pixel grid in the dispersion + # direction + interp_trace = interp1d(trace_y, trace_x, fill_value='extrapolate') + + # Get the trace position for each dispersion element within the bounding box + trace = interp_trace(np.arange(ny)) + + # Place the trace in the full array + full_trace = np.full(shape[0], np.nan) + y0 = int(np.ceil(bounding_box[1][0])) + full_trace[y0:y0 + ny] = trace + + return full_trace + def define_aperture(input_model, slit, extract_params, exp_type): """Define an extraction aperture from input parameters. @@ -1198,7 +1342,7 @@ def define_aperture(input_model, slit, extract_params, exp_type): # Extract parameters are updated in place if extract_params['use_source_posn']: # Source location from WCS - middle_pix, middle_wl, location = location_from_wcs(input_model, slit) + middle_pix, middle_wl, location, trace = location_from_wcs(input_model, slit) if location is not None: log.info(f"Computed source location is {location:.2f}, " @@ -1210,8 +1354,22 @@ def define_aperture(input_model, slit, extract_params, exp_type): nominal_location, _ = aperture_center( nominal_profile, extract_params['dispaxis'], middle_pix=middle_pix) - # Offet extract parameters by location - nominal - shift_by_source_location(location, nominal_location, extract_params) + # Offset extract parameters by location - nominal + offset = location - nominal_location + log.info(f"Nominal location is {nominal_location:.2f}, " + f"so offset is {offset:.2f} pixels") + shift_by_offset(offset, extract_params, update_trace=False) + else: + middle_pix, middle_wl, location, trace = None, None, None, None + + # Store the trace, if computed + extract_params['trace'] = trace + + # Add an extra position offset if desired, from extract_params['position_offset'] + offset = extract_params.get('position_offset', 0.0) + if offset != 0.0: + log.info(f"Applying additional cross-dispersion offset {offset:.2f} pixels") + shift_by_offset(offset, extract_params, update_trace=True) # Make a spatial profile, including source shifts if necessary profile, lower_limit, upper_limit = box_profile( @@ -1812,8 +1970,8 @@ def create_extraction(input_model, slit, output_model, def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, smoothing_length=None, bkg_fit=None, bkg_order=None, log_increment=50, subtract_background=None, - use_source_posn=None, save_profile=False, - save_scene_model=False): + use_source_posn=None, position_offset=0.0, + save_profile=False, save_scene_model=False): """Extract all 1-D spectra from an input model. Parameters @@ -1848,6 +2006,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, If True, the target and background positions specified in the reference file (or the default position, if there is no reference file) will be shifted to account for source position offset. + position_offset : float + Number of pixels to shift the nominal source position in the + cross-dispersion direction. save_profile : bool If True, the spatial profiles created for the input model will be returned as ImageModels. If False, the return value is None. @@ -1943,8 +2104,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, + subtract_background=subtract_background, use_source_posn=use_source_posn, - subtract_background=subtract_background) + position_offset=position_offset) except ContinueError: continue @@ -1979,8 +2141,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, + subtract_background=subtract_background, use_source_posn=use_source_posn, - subtract_background=subtract_background) + position_offset=position_offset) except ContinueError: pass @@ -2017,8 +2180,9 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, save_profile=save_profile, save_scene_model=save_scene_model, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, + subtract_background=subtract_background, use_source_posn=use_source_posn, - subtract_background=subtract_background) + position_offset=position_offset) except ContinueError: pass diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py index a613bedf26..3738ee51fc 100644 --- a/jwst/extract_1d/extract_1d_step.py +++ b/jwst/extract_1d/extract_1d_step.py @@ -162,13 +162,14 @@ class Extract1dStep(Step): apply_apcorr = boolean(default=True) # apply aperture corrections? use_source_posn = boolean(default=None) # use source coords to center extractions? + position_offset = float(default=0) # number of pixels to shift source trace in the cross-dispersion direction smoothing_length = integer(default=None) # background smoothing size bkg_fit = option("poly", "mean", "median", None, default=None) # background fitting type bkg_order = integer(default=None, min=0) # order of background polynomial fit log_increment = integer(default=50) # increment for multi-integration log messages save_profile = boolean(default=False) # save spatial profile to disk save_scene_model = boolean(default=False) # save flux model to disk - + center_xy = float_list(min=2, max=2, default=None) # IFU extraction x/y center ifu_autocen = boolean(default=False) # Auto source centering for IFU point source data. bkg_sigma_clip = float(default=3.0) # background sigma clipping threshold for IFU @@ -176,7 +177,7 @@ class Extract1dStep(Step): ifu_set_srctype = option("POINT", "EXTENDED", None, default=None) # user-supplied source type ifu_rscale = float(default=None, min=0.5, max=3) # Radius in terms of PSF FWHM to scale extraction radii ifu_covar_scale = float(default=1.0) # Scaling factor to apply to errors to account for IFU cube covariance - + soss_atoca = boolean(default=True) # use ATOCA algorithm soss_threshold = float(default=1e-2) # TODO: threshold could be removed from inputs. Its use is too specific now. soss_n_os = integer(default=2) # minimum oversampling factor of the underlying wavelength grid used when modeling trace. @@ -422,8 +423,9 @@ def process(self, input): self.log_increment, self.subtract_background, self.use_source_posn, + self.position_offset, self.save_profile, - self.save_scene_model + self.save_scene_model, ) # Set the step flag to complete in each model diff --git a/jwst/extract_1d/tests/conftest.py b/jwst/extract_1d/tests/conftest.py index 467f03f385..00456e991f 100644 --- a/jwst/extract_1d/tests/conftest.py +++ b/jwst/extract_1d/tests/conftest.py @@ -34,9 +34,24 @@ def simple_wcs_function(x, y): # Add a bounding box simple_wcs_function.bounding_box = wcs_bbox_from_shape(shape) - # Add a few expected attributes, so they can be monkeypatched as needed - simple_wcs_function.get_transform = None - simple_wcs_function.backward_transform = None + # Define a simple transform + def get_transform(*args, **kwargs): + def return_results(*args, **kwargs): + if len(args) == 2: + zeros = np.zeros(args[0].shape) + wave, _ = np.meshgrid(args[0], args[1]) + return zeros, zeros, wave + if len(args) == 3: + try: + nx = len(args[0]) + except TypeError: + nx = 1 + pix = np.arange(nx) + trace = np.ones(nx) + return pix, trace + return return_results + + simple_wcs_function.get_transform = get_transform simple_wcs_function.available_frames = [] return simple_wcs_function @@ -68,10 +83,17 @@ def simple_wcs_function(x, y): # Add a bounding box simple_wcs_function.bounding_box = wcs_bbox_from_shape(shape) - # Add a few expected attributes, so they can be monkeypatched as needed - simple_wcs_function.get_transform = None - simple_wcs_function.backward_transform = None - simple_wcs_function.available_frames = [] + # Mock a simple backward transform + def backward_transform(*args, **kwargs): + try: + nx = len(args[0]) + except TypeError: + nx = 1 + pix = np.arange(nx) + trace = np.ones(nx) + return trace, pix + + simple_wcs_function.backward_transform = backward_transform return simple_wcs_function diff --git a/jwst/extract_1d/tests/test_expected_skips.py b/jwst/extract_1d/tests/test_expected_skips.py index d3e3bbe239..2ba7a17a85 100644 --- a/jwst/extract_1d/tests/test_expected_skips.py +++ b/jwst/extract_1d/tests/test_expected_skips.py @@ -48,7 +48,7 @@ def test_expected_skip_niriss_soss_full(mock_niriss_full): def test_expected_skip_niriss_soss_f277w(mock_niriss_f277w): - + with mock_niriss_f277w as model: result = Extract1dStep().process(model) result2 = PhotomStep().process(result) diff --git a/jwst/extract_1d/tests/test_extract.py b/jwst/extract_1d/tests/test_extract.py index a66b19c55c..9e4dbd5baa 100644 --- a/jwst/extract_1d/tests/test_extract.py +++ b/jwst/extract_1d/tests/test_extract.py @@ -30,7 +30,8 @@ def extract1d_ref_dict(): {'id': 'slit5', 'bkg_coeff': None}, {'id': 'slit6', 'use_source_posn': True}, {'id': 'slit7', 'spectral_order': 20}, - {'id': 'S200A1'} + {'id': 'S200A1'}, + {'id': 'S1600A1', 'use_source_posn': False} ] ref_dict = {'apertures': apertures} return ref_dict @@ -58,6 +59,8 @@ def extract_defaults(): 'spectral_order': 1, 'src_coeff': None, 'subtract_background': False, + 'position_offset': 0.0, + 'trace': None, 'use_source_posn': False, 'xstart': 0, 'xstop': 49, @@ -197,13 +200,14 @@ def test_get_extract_parameters_no_match( def test_get_extract_parameters_source_posn_exptype( mock_nirspec_bots, extract1d_ref_dict, extract_defaults): input_model = mock_nirspec_bots + input_model.meta.exposure.type = 'NRS_LAMP' # match a bare entry params = ex.get_extract_parameters( extract1d_ref_dict, input_model, 'slit1', 1, input_model.meta, use_source_posn=None) - # use_source_posn is switched off for NRS_BRIGHTOBJ + # use_source_posn is switched off for unknown exptypes assert params['use_source_posn'] is False @@ -749,6 +753,45 @@ def test_box_profile_from_width(extract_defaults, dispaxis): assert np.all(profile[8:] == 0.0) +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_box_profile_from_trace(extract_defaults, dispaxis): + shape = (10, 10) + wl_array = np.empty(shape) + wl_array[:] = np.linspace(3, 5, 10) + + params = extract_defaults + params['dispaxis'] = dispaxis + + # Set a linear trace + params['trace'] = np.arange(10) + 1.5 + + # Set the width to 4 pixels + params['extract_width'] = 4.0 + + # Make the profile + profile, lower, upper = ex.box_profile( + shape, extract_defaults, wl_array, return_limits=True) + if dispaxis == 2: + profile = profile.T + + expected = [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]] + + assert np.allclose(profile, expected) + + # upper and lower limits are averages + assert np.isclose(lower, 4.5) + assert np.isclose(upper, 7.5) + + @pytest.mark.parametrize('middle', [None, 7]) @pytest.mark.parametrize('dispaxis', [1, 2]) def test_aperture_center(middle, dispaxis): @@ -818,27 +861,19 @@ def test_location_from_wcs_nirspec( monkeypatch, mock_nirspec_fs_one_slit, resampled, is_slit, missing_bbox): model = mock_nirspec_fs_one_slit - # monkey patch in a transform for the wcs - def slit2det(*args, **kwargs): - def return_one(*args, **kwargs): - return 0.0, 1.0 - return return_one - - monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) - if not resampled: - # also mock available frames, so it looks like unresampled cal data + # mock available frames, so it looks like unresampled cal data monkeypatch.setattr(model.meta.wcs, 'available_frames', ['gwa']) if missing_bbox: - # also mock a missing bounding box - should have same results + # mock a missing bounding box - should have same results # for the test data monkeypatch.setattr(model.meta.wcs, 'bounding_box', None) if is_slit: - middle, middle_wl, location = ex.location_from_wcs(model, model) + middle, middle_wl, location, trace = ex.location_from_wcs(model, model) else: - middle, middle_wl, location = ex.location_from_wcs(model, None) + middle, middle_wl, location, trace = ex.location_from_wcs(model, None) # middle pixel is center of dispersion axis assert middle == int((model.data.shape[1] - 1) / 2) @@ -849,6 +884,9 @@ def return_one(*args, **kwargs): # location is 1.0 - from the mocked transform function assert location == 1.0 + # trace is the same, in an array + assert np.all(trace == 1.0) + @pytest.mark.parametrize('is_slit', [True, False]) def test_location_from_wcs_miri(monkeypatch, mock_miri_lrs_fs, is_slit): @@ -862,11 +900,17 @@ def return_one(*args, **kwargs): monkeypatch.setattr(model.meta.wcs, 'backward_transform', radec2det()) + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-2], 1.0) + + monkeypatch.setattr(ex, '_miri_trace_from_wcs', mock_trace) + # Get the slit center from the WCS if is_slit: - middle, middle_wl, location = ex.location_from_wcs(model, model) + middle, middle_wl, location, trace = ex.location_from_wcs(model, model) else: - middle, middle_wl, location = ex.location_from_wcs(model, None) + middle, middle_wl, location, trace = ex.location_from_wcs(model, None) # middle pixel is center of dispersion axis assert middle == int((model.data.shape[0] - 1) / 2) @@ -877,12 +921,18 @@ def return_one(*args, **kwargs): # location is 1.0 - from the mocked transform function assert location == 1.0 + # trace is the same, in an array + assert np.all(trace == 1.0) + def test_location_from_wcs_missing_data(mock_miri_lrs_fs, log_watcher): + model = mock_miri_lrs_fs + model.meta.wcs.backward_transform = None + # model is missing WCS information - None values are returned log_watcher.message = "Dithered pointing location not found" - result = ex.location_from_wcs(mock_miri_lrs_fs, None) - assert result == (None, None, None) + result = ex.location_from_wcs(model, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -890,7 +940,7 @@ def test_location_from_wcs_wrong_exptype(mock_niriss_soss, log_watcher): # model is not a handled exposure type log_watcher.message = "Source position cannot be found for EXP_TYPE" result = ex.location_from_wcs(mock_niriss_soss, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -909,7 +959,7 @@ def return_one(*args, **kwargs): # WCS transform returns NaN for the location log_watcher.message = "Source position could not be determined" result = ex.location_from_wcs(model, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() @@ -925,59 +975,103 @@ def return_one(*args, **kwargs): monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-1], 1.0) + + monkeypatch.setattr(ex, '_nirspec_trace_from_wcs', mock_trace) + # WCS transform a value outside the bounding box log_watcher.message = "outside the bounding box" result = ex.location_from_wcs(model, None) - assert result == (None, None, None) + assert result == (None, None, None, None) log_watcher.assert_seen() -def test_shift_by_source_location_horizontal(extract_defaults): - location = 12.5 - nominal_location = 15.0 - offset = location - nominal_location +def test_shift_by_offset_horizontal(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['xstart'] == extract_defaults['xstart'] assert extract_params['xstop'] == extract_defaults['xstop'] assert extract_params['ystart'] == extract_defaults['ystart'] + offset assert extract_params['ystop'] == extract_defaults['ystop'] + offset -def test_shift_by_source_location_vertical(extract_defaults): - location = 12.5 - nominal_location = 15.0 - offset = location - nominal_location +def test_shift_by_offset_vertical(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 2 + extract_params['position_offset'] = offset - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['xstart'] == extract_defaults['xstart'] + offset assert extract_params['xstop'] == extract_defaults['xstop'] + offset assert extract_params['ystart'] == extract_defaults['ystart'] assert extract_params['ystop'] == extract_defaults['ystop'] -def test_shift_by_source_location_coeff(extract_defaults): - location = 6.5 - nominal_location = 4.0 - offset = location - nominal_location +def test_shift_by_offset_coeff(extract_defaults): + offset = 2.5 extract_params = extract_defaults.copy() extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset extract_params['src_coeff'] = [[2.5, 1.0], [6.5, 1.0]] extract_params['bkg_coeff'] = [[-0.5], [3.0], [6.0], [9.5]] - ex.shift_by_source_location(location, nominal_location, extract_params) + ex.shift_by_offset(offset, extract_params) assert extract_params['src_coeff'] == [[2.5 + offset, 1.0], [6.5 + offset, 1.0]] assert extract_params['bkg_coeff'] == [[-0.5 + offset], [3.0 + offset], [6.0 + offset], [9.5 + offset]] +def test_shift_by_offset_trace(extract_defaults): + offset = 2.5 + + extract_params = extract_defaults.copy() + extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset + extract_params['trace'] = np.arange(10, dtype=float) + + ex.shift_by_offset(offset, extract_params, update_trace=True) + assert np.all(extract_params['trace'] == np.arange(10) + offset) + + +def test_shift_by_offset_trace_no_update(extract_defaults): + offset = 2.5 + + extract_params = extract_defaults.copy() + extract_params['dispaxis'] = 1 + extract_params['position_offset'] = offset + extract_params['trace'] = np.arange(10, dtype=float) + + ex.shift_by_offset(offset, extract_params, update_trace=False) + assert np.all(extract_params['trace'] == np.arange(10)) + + +def test_nirspec_trace_from_wcs(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + trace = ex._nirspec_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[-1])) + + +def test_miri_trace_from_wcs(mock_miri_lrs_fs): + model = mock_miri_lrs_fs + trace = ex._miri_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[-1])) + + @pytest.mark.parametrize('is_slit', [True, False]) def test_define_aperture_nirspec(mock_nirspec_fs_one_slit, extract_defaults, is_slit): model = mock_nirspec_fs_one_slit @@ -1093,7 +1187,7 @@ def test_define_aperture_use_source(monkeypatch, mock_nirspec_fs_one_slit, extra # mock the source location function def mock_source_location(*args): - return 24, 7.74, 9.5 + return 24, 7.74, 9.5, np.full(model.data.shape[-1], 9.5) monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) @@ -1109,6 +1203,24 @@ def mock_source_location(*args): assert np.all(profile[13:] == 0.0) +def test_define_aperture_extra_offset(mock_nirspec_fs_one_slit, extract_defaults): + model = mock_nirspec_fs_one_slit + extract_defaults['dispaxis'] = 1 + slit = None + exptype = 'NRS_FIXEDSLIT' + + extract_defaults['position_offset'] = 2.0 + + result = ex.define_aperture(model, slit, extract_defaults, exptype) + _, _, _, profile, _, limits = result + assert profile.shape == model.data.shape + + # Default profile is shifted 2 pixels up + assert np.all(profile[:2] == 0.0) + assert np.all(profile[2:] == 1.0) + assert limits == (2, model.data.shape[0] + 1, 0, model.data.shape[1] - 1) + + def test_extract_one_slit_horizontal(mock_nirspec_fs_one_slit, extract_defaults, simple_profile, background_profile): # update parameters to subtract background @@ -1314,7 +1426,8 @@ def test_create_extraction_missing_wavelengths(create_extraction_inputs, log_wat model.wavelength = np.full_like(model.data, np.nan) log_watcher.message = 'Spectrum is empty; no valid data' with pytest.raises(ex.ContinueError): - ex.create_extraction(*create_extraction_inputs) + with pytest.warns(RuntimeWarning, match='All-NaN'): + ex.create_extraction(*create_extraction_inputs) log_watcher.assert_seen() @@ -1336,9 +1449,11 @@ def test_create_extraction_one_int(create_extraction_inputs, mock_nirspec_bots, model = mock_nirspec_bots model.data = model.data[0].reshape(1, *model.data.shape[-2:]) create_extraction_inputs[0] = model + create_extraction_inputs[4] = 'S1600A1' log_watcher.message = '1 integration done' - ex.create_extraction(*create_extraction_inputs, log_increment=1) + ex.create_extraction( + *create_extraction_inputs, log_increment=1) output_model = create_extraction_inputs[2] assert len(output_model.spec) == 1 log_watcher.assert_seen() @@ -1347,6 +1462,7 @@ def test_create_extraction_one_int(create_extraction_inputs, mock_nirspec_bots, def test_create_extraction_log_increment( create_extraction_inputs, mock_nirspec_bots, log_watcher): create_extraction_inputs[0] = mock_nirspec_bots + create_extraction_inputs[4] = 'S1600A1' # all integrations are logged log_watcher.message = '... 9 integrations done' @@ -1365,7 +1481,7 @@ def test_create_extraction_use_source( # mock the source location function def mock_source_location(*args): - return 24, 7.74, 9.5 + return 24, 7.74, 9.5, np.full(model.data.shape[-1], 9.5) monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) @@ -1377,12 +1493,42 @@ def mock_source_location(*args): # source position is used log_watcher.message = 'Aperture start/stop: -15' else: - # If False, source position is not used + # If False, source position is not used log_watcher.message = 'Aperture start/stop: 0' ex.create_extraction(*create_extraction_inputs, use_source_posn=use_source) log_watcher.assert_seen() +@pytest.mark.parametrize('extract_width', [None, 7]) +def test_create_extraction_use_trace( + monkeypatch, create_extraction_inputs, mock_nirspec_bots, + extract_width, log_watcher): + model = mock_nirspec_bots + create_extraction_inputs[0] = model + aper = create_extraction_inputs[3]['apertures'] + create_extraction_inputs[4] = 'S1600A1' + for i in range(len(aper)): + if aper[i]['id'] == 'S1600A1': + aper[i]['use_source_posn'] = True + aper[i]['extract_width'] = extract_width + aper[i]['position_offset'] = 0 + + # mock the source location function + def mock_source_location(*args): + return 24, 7.74, 25, np.full(model.data.shape[-1], 25) + + monkeypatch.setattr(ex, 'location_from_wcs', mock_source_location) + if extract_width is not None: + # If explicitly set to True, or unspecified + source type is POINT, + # source position is used + log_watcher.message = 'aperture start/stop from trace: 22' + else: + # If False, source trace is not used + log_watcher.message = 'Aperture start/stop: 0' + ex.create_extraction(*create_extraction_inputs) + log_watcher.assert_seen() + + def test_run_extract1d(mock_nirspec_mos): model = mock_nirspec_mos output_model, profile_model, scene_model = ex.run_extract1d(model) @@ -1429,7 +1575,6 @@ def test_run_extract1d_save_cube_scene(mock_nirspec_bots): scene_model.close() - def test_run_extract1d_tso(mock_nirspec_bots): model = mock_nirspec_bots output_model, _, _ = ex.run_extract1d(model) diff --git a/jwst/extract_1d/tests/test_extract_1d_step.py b/jwst/extract_1d/tests/test_extract_1d_step.py index 98e0aaeb63..0493679097 100644 --- a/jwst/extract_1d/tests/test_extract_1d_step.py +++ b/jwst/extract_1d/tests/test_extract_1d_step.py @@ -52,7 +52,8 @@ def test_extract_nirspec_mos_multi_slit(mock_nirspec_mos, simple_wcs): def test_extract_nirspec_bots(mock_nirspec_bots, simple_wcs): - result = Extract1dStep.call(mock_nirspec_bots, apply_apcorr=False) + result = Extract1dStep.call( + mock_nirspec_bots, apply_apcorr=False, use_source_posn=False) assert result.meta.cal_step.extract_1d == 'COMPLETE' assert (result.spec[0].name == 'S1600A1') diff --git a/jwst/regtest/test_nirspec_bots_extract1d.py b/jwst/regtest/test_nirspec_bots_extract1d.py index 148b60e41f..d04c7cfbfc 100644 --- a/jwst/regtest/test_nirspec_bots_extract1d.py +++ b/jwst/regtest/test_nirspec_bots_extract1d.py @@ -20,6 +20,7 @@ def run_extract(rtdata_module, request): # Run the calwebb_spec2 pipeline; args = ["extract_1d", rtdata.input, f"--override_extract1d={ref_file}", + "--use_source_posn=False", "--suffix=x1dints"] Step.from_cmdline(args)