From 1902445d9d6fd0968c847daba965cef30ce17bfa Mon Sep 17 00:00:00 2001 From: ojustino Date: Fri, 30 Sep 2022 09:57:35 -0400 Subject: [PATCH] Better handled Spectrum1D images across classes --- specreduce/background.py | 17 +++++++++------ specreduce/extract.py | 45 ++++++++++++++++++++++++---------------- specreduce/tracing.py | 20 ++++++++++++++---- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/specreduce/background.py b/specreduce/background.py index a2829622..d8b7def3 100644 --- a/specreduce/background.py +++ b/specreduce/background.py @@ -26,7 +26,7 @@ class Background: Parameters ---------- - image : `~astropy.nddata.NDData` or array-like + image : `~astropy.nddata.NDData`-like or array-like image with 2-D spectral image data traces : List list of trace objects (or integers to define FlatTraces) to @@ -59,7 +59,7 @@ def __post_init__(self): Parameters ---------- - image : `~astropy.nddata.NDData` or array-like + image : `~astropy.nddata.NDData`-like or array-like image with 2-D spectral image data traces : List list of trace objects (or integers to define FlatTraces) to @@ -85,6 +85,11 @@ def _to_trace(trace): raise ValueError('trace_object.trace_pos must be >= 1') return trace + if isinstance(self.image, NDData): + # NOTE: should the NDData structure instead be preserved? + # (NDData includes Spectrum1D under its umbrella) + self.image = self.image.data + bkg_wimage = np.zeros_like(self.image, dtype=np.float64) for trace in self.traces: trace = _to_trace(trace) @@ -132,7 +137,7 @@ def two_sided(cls, image, trace_object, separation, **kwargs): Parameters ---------- - image : nddata-compatible image + image : `~astropy.nddata.NDData`-like or array-like image with 2-D spectral image data trace_object: Trace estimated trace of the spectrum to center the background traces @@ -165,7 +170,7 @@ def one_sided(cls, image, trace_object, separation, **kwargs): Parameters ---------- - image : nddata-compatible image + image : `~astropy.nddata.NDData`-like or array-like image with 2-D spectral image data trace_object: Trace estimated trace of the spectrum to center the background traces @@ -192,7 +197,7 @@ def bkg_image(self, image=None): Parameters ---------- - image : nddata-compatible image or None + image : `~astropy.nddata.NDData`-like, array-like, or None image with 2-D spectral image data. If None, will use ``image`` passed to extract the background. @@ -211,7 +216,7 @@ def sub_image(self, image=None): Parameters ---------- - image : nddata-compatible image or None + image : `~astropy.nddata.NDData`-like, array-like, or None image with 2-D spectral image data. If None, will use ``image`` passed to extract the background. diff --git a/specreduce/extract.py b/specreduce/extract.py index 2537d7ba..768c6160 100644 --- a/specreduce/extract.py +++ b/specreduce/extract.py @@ -116,15 +116,15 @@ class BoxcarExtract(SpecreduceOperation): Parameters ---------- - image : nddata-compatible image + image : `~astropy.nddata.NDData`-like or array-like, required image with 2-D spectral image data - trace_object : Trace + trace_object : Trace, required trace object - width : float + width : float, optional width of extraction aperture in pixels - disp_axis : int + disp_axis : int, optional dispersion axis - crossdisp_axis : int + crossdisp_axis : int, optional cross-dispersion axis Returns @@ -150,15 +150,15 @@ def __call__(self, image=None, trace_object=None, width=None, Parameters ---------- - image : nddata-compatible image + image : `~astropy.nddata.NDData`-like or array-like, required image with 2-D spectral image data - trace_object : Trace + trace_object : Trace, required trace object - width : float + width : float, optional width of extraction aperture in pixels [default: 5] - disp_axis : int + disp_axis : int, optional dispersion axis [default: 1] - crossdisp_axis : int + crossdisp_axis : int, optional cross-dispersion axis [default: 0] @@ -174,25 +174,33 @@ def __call__(self, image=None, trace_object=None, width=None, disp_axis = disp_axis if disp_axis is not None else self.disp_axis crossdisp_axis = crossdisp_axis if crossdisp_axis is not None else self.crossdisp_axis + # handle image processing based on its type + if isinstance(image, Spectrum1D): + img = image.data + unit = image.unit + else: + img = image + unit = getattr(image, 'unit', u.DN) + # TODO: this check can be removed if/when implemented as a check in FlatTrace if isinstance(trace_object, FlatTrace): if trace_object.trace_pos < 1: raise ValueError('trace_object.trace_pos must be >= 1') # weight image to use for extraction - wimage = _ap_weight_image( + wimg = _ap_weight_image( trace_object, width, disp_axis, crossdisp_axis, - image.shape) + img.shape) # extract - ext1d = np.sum(image * wimage, axis=crossdisp_axis) + ext1d = np.sum(img * wimg, axis=crossdisp_axis) * unit - # TODO: add wavelenght units, uncertainty and mask to spectrum1D object - spec = Spectrum1D(spectral_axis=np.arange(len(ext1d)) * u.pixel, - flux=ext1d * getattr(image, 'unit', u.DN)) + # TODO: add wavelength units, uncertainty and mask to Spectrum1D object + pixels = np.arange(ext1d.shape[crossdisp_axis]) * u.pixel + spec = Spectrum1D(spectral_axis=pixels, flux=ext1d) return spec @@ -206,7 +214,7 @@ class HorneExtract(SpecreduceOperation): Parameters ---------- - image : `~astropy.nddata.NDData` or array-like, required + image : `~astropy.nddata.NDData`-like or array-like, required The input 2D spectrum from which to extract a source. An NDData object must specify uncertainty and a mask. An array requires use of the ``variance``, ``mask``, & ``unit`` arguments. @@ -269,7 +277,7 @@ def __call__(self, image=None, trace_object=None, Parameters ---------- - image : `~astropy.nddata.NDData` or array-like, required + image : `~astropy.nddata.NDData`-like or array-like, required The input 2D spectrum from which to extract a source. An NDData object must specify uncertainty and a mask. An array requires use of the ``variance``, ``mask``, & ``unit`` arguments. @@ -322,6 +330,7 @@ def __call__(self, image=None, trace_object=None, # handle image and associated data based on image's type if isinstance(image, NDData): + # (NDData includes Spectrum1D under its umbrella) img = np.ma.array(image.data, mask=image.mask) unit = image.unit if image.unit is not None else u.Unit() diff --git a/specreduce/tracing.py b/specreduce/tracing.py index 6b2557ca..a78ad7d6 100644 --- a/specreduce/tracing.py +++ b/specreduce/tracing.py @@ -5,9 +5,10 @@ import warnings from astropy.modeling import fitting, models -from astropy.nddata import CCDData, NDData +from astropy.nddata import NDData from astropy.stats import gaussian_sigma_to_fwhm from scipy.interpolate import UnivariateSpline +from specutils import Spectrum1D import numpy as np __all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace'] @@ -20,7 +21,7 @@ class Trace: Parameters ---------- - image : `~astropy.nddata.CCDData` + image : `~astropy.nddata.NDData`-like or array-like, required Image to be traced Properties @@ -28,7 +29,7 @@ class Trace: shape : tuple Shape of the array describing the trace """ - image: CCDData + image: NDData def __post_init__(self): self.trace_pos = self.image.shape[0] / 2 @@ -37,6 +38,11 @@ def __post_init__(self): def __getitem__(self, i): return self.trace[i] + def _parse_image(self): + if isinstance(self.image, Spectrum1D): + # NOTE: should the Spectrum1D structure instead be preserved? + self.image = self.image.data + @property def shape(self): return self.trace.shape @@ -95,6 +101,8 @@ class FlatTrace(Trace): trace_pos: float def __post_init__(self): + super()._parse_image() + self.set_position(self.trace_pos) def set_position(self, trace_pos): @@ -124,6 +132,8 @@ class ArrayTrace(Trace): trace: np.ndarray def __post_init__(self): + super()._parse_image() + nx = self.image.shape[1] nt = len(self.trace) if nt != nx: @@ -158,7 +168,7 @@ class KosmosTrace(Trace): Parameters ---------- - image : `~astropy.nddata.NDData` or array-like, required + image : `~astropy.nddata.NDData`-like or array-like, required The image over which to run the trace. Assumes cross-dispersion (spatial) direction is axis 0 and dispersion (wavelength) direction is axis 1. @@ -200,6 +210,8 @@ class KosmosTrace(Trace): _disp_axis = 1 def __post_init__(self): + super()._parse_image() + # handle multiple image types and mask uncaught invalid values if isinstance(self.image, NDData): img = np.ma.masked_invalid(np.ma.masked_array(self.image.data,