Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-2546: updates and fixes for wfss_contam step #8417

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions jwst/assign_wcs/niriss.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,11 @@ def wfss(input_model, reference_files):
# Get the disperser parameters which are defined as a model for each
# spectral order
with NIRISSGrismModel(reference_files['specwcs']) as f:
dispx = f.dispx
dispy = f.dispy
displ = f.displ
invdispl = f.invdispl
orders = f.orders
dispx = f.dispx.instance
dispy = f.dispy.instance
displ = f.displ.instance
invdispl = f.invdispl.instance
orders = f.orders.instance
fwcpos_ref = f.fwcpos_ref

# This is the actual rotation from the input model
Expand Down
152 changes: 121 additions & 31 deletions jwst/wfss_contam/disperse.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,123 @@
from functools import partial
import numpy as np
from typing import Callable, Sequence
from astropy.wcs import WCS

from scipy.interpolate import interp1d
import warnings

from ..lib.winclip import get_clipped_pixels
from .sens1d import create_1d_sens


def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax,
sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis,
oversample_factor=2, extrapolate_sed=False, xoffset=0,
yoffset=0):
def flat_lam(fluxes: np.ndarray, lams: np.ndarray) -> np.ndarray:
'''
Parameters
----------
x : float
x-coordinate of the pixel.
lams : float array
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.

Returns
-------
lams : float array
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.
'''
return fluxes[0]


def flux_interpolator_injector(lams: np.ndarray,
flxs: np.ndarray,
extrapolate_sed: bool,
) -> Callable[[float], float]:
'''
Parameters
----------
lams : float array
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.
flxs : float array
Array of fluxes (flam) for the pixels contained in x0, y0. If a single
direct image is in use, this will be a single value.
extrapolate_sed : bool
Whether to allow for the SED of the object to be extrapolated when it does not fully cover the
needed wavelength range. Default if False.

Returns
-------
flux : function
Function that returns the flux at a given wavelength. If only one direct image is in use, this
function will always return the same value
'''

if len(lams) > 1:
# If we have direct image flux values from more than one filter (lams),
# we have the option to extrapolate the fluxes outside the
# wavelength range of the direct images
if extrapolate_sed is False:
return interp1d(lams, flxs, fill_value=0., bounds_error=False)
else:
return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False)
else:
# If we only have flux from one wavelength, just use that
# single flux value at all wavelengths
return partial(flat_lam, flxs)


def determine_wl_spacing(dw: float,
lams: np.ndarray,
oversample_factor: int,
) -> float:
'''
Use a natural wavelength scale or the wavelength scale of the input SED/spectrum,
whichever is smaller, divided by oversampling requested

Parameters
----------
dw : float
The natural wavelength scale of the grism image
lams : float array
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.
oversample_factor : int
The amount of oversampling

Returns
-------
dlam : float
The wavelength spacing to use for the dispersed pixels
'''
#
if len(lams) > 1:
input_dlam = np.median(lams[1:] - lams[:-1])
if input_dlam < dw:
return input_dlam / oversample_factor
return dw / oversample_factor


def dispersed_pixel(x0: np.ndarray,
y0: np.ndarray,
width: float,
height: float,
lams: np.ndarray,
flxs: np.ndarray,
order: int,
wmin: float,
wmax: float,
sens_waves: np.ndarray,
sens_resp: np.ndarray,
seg_wcs: WCS,
grism_wcs: WCS,
ID: int,
naxis: Sequence[int],
oversample_factor: int = 2,
extrapolate_sed: bool = False,
xoffset: float = 0,
yoffset: float = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
"""
This function take a list of pixels and disperses them using the information contained
in the grism image WCS object and returns a list of dispersed pixels and fluxes.
Expand Down Expand Up @@ -83,20 +191,8 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax,
sky_to_imgxy = grism_wcs.get_transform('world', 'detector')
imgxy_to_grismxy = grism_wcs.get_transform('detector', 'grism_detector')

# Setup function for retrieving flux values at each dispersed wavelength
if len(lams) > 1:
# If we have direct image flux values from more than one filter (lambda),
# we have the option to extrapolate the fluxes outside the
# wavelength range of the direct images
if extrapolate_sed is False:
flux = interp1d(lams, flxs, fill_value=0., bounds_error=False)
else:
flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False)
else:
# If we only have flux from one lambda, just use that
# single flux value at all wavelengths
def flux(x):
return flxs[0]
# Set up function for retrieving flux values at each dispersed wavelength
flux_interpolator = flux_interpolator_injector(lams, flxs, extrapolate_sed)

# Get x/y positions in the grism image corresponding to wmin and wmax:
# Start with RA/Dec of the input pixel position in segmentation map,
Expand All @@ -110,19 +206,9 @@ def flux(x):
dxw = xwmax - xwmin
dyw = ywmax - ywmin

# Compute the delta-wave per pixel
dw = np.abs((wmax - wmin) / (dyw - dxw))

# Use a natural wavelength scale or the wavelength scale of the input SED/spectrum,
# whichever is smaller, divided by oversampling requested
input_dlam = np.median(lams[1:] - lams[:-1])
if input_dlam < dw:
dlam = input_dlam / oversample_factor
else:
# this value gets used when we only have 1 direct image wavelength
dlam = dw / oversample_factor

# Create list of wavelengths on which to compute dispersed pixels
dw = np.abs((wmax - wmin) / (dyw - dxw))
dlam = determine_wl_spacing(dw, lams, oversample_factor)
lambdas = np.arange(wmin, wmax + dlam, dlam)
n_lam = len(lambdas)

Expand Down Expand Up @@ -161,7 +247,11 @@ def flux(x):
# values are naturally in units of physical fluxes, so we divide out
# the sensitivity (flux calibration) values to convert to units of
# countrate (DN/s).
counts = flux(lams) * areas / sens
# flux_interpolator(lams) is either single-valued (for a single direct image)
# or an array of the same length as lams (for multiple direct images in different filters)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero")
counts = flux_interpolator(lams) * areas / (sens * oversample_factor)
counts[no_cal] = 0. # set to zero where no flux cal info available

return xs, ys, areas, lams, counts, ID
Loading
Loading