diff --git a/specreduce/extract.py b/specreduce/extract.py index b792be88..af3020a8 100644 --- a/specreduce/extract.py +++ b/specreduce/extract.py @@ -117,6 +117,34 @@ def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape): return wimage +def _align_along_trace(img, trace_array, disp_axis=1, crossdisp_axis=0): + """ + Given an arbitrary trace ``trace_array`` (an np.ndarray), roll + all columns of ``nddata`` to shift the NDData's pixels nearest + to the trace to the center of the spatial dimension of the + NDData. + """ + # TODO: this workflow does not support extraction for >2D spectra + if not (disp_axis == 1 and crossdisp_axis == 0): + # take the transpose to ensure the rows are the cross-disp axis: + img = img.T + + n_rows, n_cols = img.shape + + # indices of all columns, in their original order + rows = np.broadcast_to(np.arange(n_rows)[:, None], img.shape) + cols = np.broadcast_to(np.arange(n_cols), img.shape) + + # we want to "roll" each column so that the trace sits in + # the central row of the final image + shifts = trace_array.astype(int) - n_rows // 2 + + # we wrap the indices so we don't index out of bounds + shifted_rows = np.mod(rows + shifts[None, :], n_rows) + + return img[shifted_rows, cols] + + @dataclass class BoxcarExtract(SpecreduceOperation): """ @@ -462,6 +490,23 @@ def __call__(self, image=None, trace_object=None, img = np.ma.masked_array(self.image.data, or_mask) mask = img.mask + # If the trace is not flat, shift the rows in each column + # so the image is aligned along the trace: + if isinstance(trace_object, FlatTrace): + mean_init_guess = trace_object.trace + else: + img = _align_along_trace( + img, + trace_object.trace, + disp_axis=disp_axis, + crossdisp_axis=crossdisp_axis + ) + # Choose the initial guess for the mean of + # the Gaussian profile: + mean_init_guess = np.broadcast_to( + img.shape[crossdisp_axis] // 2, img.shape[disp_axis] + ) + # co-add signal in each image column ncols = img.shape[crossdisp_axis] xd_pixels = np.arange(ncols) # y plot dir / x spec dir @@ -483,7 +528,8 @@ def __call__(self, image=None, trace_object=None, norms = [] for col_pix in range(img.shape[disp_axis]): # set gaussian model's mean as column's corresponding trace value - fit_ext_kernel.mean_0 = trace_object.trace[col_pix] + fit_ext_kernel.mean_0 = mean_init_guess[col_pix] + # NOTE: support for variable FWHMs forthcoming and would be here # fit compound model to column diff --git a/specreduce/tests/test_extract.py b/specreduce/tests/test_extract.py index c711f2e3..8bd54d83 100644 --- a/specreduce/tests/test_extract.py +++ b/specreduce/tests/test_extract.py @@ -3,8 +3,12 @@ import astropy.units as u from astropy.nddata import CCDData, VarianceUncertainty, UnknownUncertainty +from astropy.tests.helper import assert_quantity_allclose +from astropy.utils.exceptions import AstropyUserWarning -from specreduce.extract import BoxcarExtract, HorneExtract, OptimalExtract +from specreduce.extract import ( + BoxcarExtract, HorneExtract, OptimalExtract, _align_along_trace +) from specreduce.tracing import FlatTrace, ArrayTrace @@ -149,3 +153,47 @@ def test_horne_variance_errors(): # object doesn't have those attributes (e.g., numpy and Quantity arrays) ext = extract(image=image.data, variance=err, mask=image.mask, unit=u.Jy) + + +def test_horne_non_flat_trace(): + # create a synthetic "2D spectrum" and its non-flat trace + n_rows, n_cols = (10, 50) + original = np.zeros((n_rows, n_cols)) + original[n_rows // 2] = 1 + + # create small offsets along each column to specify a non-flat trace + trace_offset = np.polyval([2e-3, -0.01, 0], np.arange(n_cols)).astype(int) + exact_trace = n_rows // 2 - trace_offset + + # re-index the array with the offsets applied to the trace (make it non-flat): + rows = np.broadcast_to(np.arange(n_rows)[:, None], original.shape) + cols = np.broadcast_to(np.arange(n_cols), original.shape) + roll_rows = np.mod(rows + trace_offset[None, :], n_rows) + rolled = original[roll_rows, cols] + + # all zeros are treated as non-weighted (give non-zero fluxes) + err = 0.1 * np.ones_like(rolled) + mask = np.zeros_like(rolled).astype(bool) + + # unroll the trace using the Horne extract utility function for alignment: + unrolled = _align_along_trace(rolled, n_rows // 2 - trace_offset) + + # ensure that mask is correctly unrolled back to its original alignment: + np.testing.assert_allclose(unrolled, original) + + # These synthetic extractions don't fit well with a Gaussian, so will pass warning: + with pytest.warns(AstropyUserWarning, match="The fit may be unsuccessful"): + # Extract the spectrum from the non-flat image+trace + extract_non_flat = HorneExtract( + rolled, ArrayTrace(rolled, exact_trace), + variance=err, mask=mask, unit=u.Jy + )() + + # Also extract the spectrum from the image after alignment with a flat trace + extract_flat = HorneExtract( + unrolled, FlatTrace(unrolled, n_rows // 2), + variance=err, mask=mask, unit=u.Jy + )() + + # ensure both extractions are equivalent: + assert_quantity_allclose(extract_non_flat.flux, extract_flat.flux)