Skip to content

Commit

Permalink
feat: add correlation mode to template_correlation (#1114)
Browse files Browse the repository at this point in the history
* feat: add correlation mode to template_correlation

* Apply suggestions from code review

* Clean up trailing whitespace

---------

Co-authored-by: Ricky O'Steen <[email protected]>
  • Loading branch information
zhukgleb and rosteen authored Dec 8, 2023
1 parent f678dbd commit fb5bd08
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions specutils/analysis/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from astropy.nddata import StdDevUncertainty
from astropy.units import Quantity
from scipy.signal.windows import tukey
from scipy.signal import correlate

from ..manipulation import LinearInterpolatedResampler
from .. import Spectrum1D
Expand All @@ -14,7 +15,7 @@


def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS,
apodization_window=0.5, resample=True):
apodization_window=0.5, resample=True, method="direct"):
"""
Compute cross-correlation of the observed and template spectra.
Expand Down Expand Up @@ -49,6 +50,11 @@ def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS,
``template_logwl_resample(spectrum, template, delta_log_wavelength=.1)``.
If False, *no* resampling is performed (and the user is responsible for
a sensible resampling).
method: str
If you choose "FFT", the correlation will be done through the use
of convolution and will be calculated faster (for small spectral
resolutions it is often correct), otherwise the correlation is determined
directly from sums (the "direct" method in `~scipy.signal.correlate`).
Returns
-------
Expand Down Expand Up @@ -84,9 +90,14 @@ def template_correlate(observed_spectrum, template_spectrum, lag_units=_KMS,
normalization = 1.

# Correlate
corr = np.correlate(observed_log_spectrum.flux.value,
(template_log_spectrum.flux.value * normalization),
mode='full')
if method.lower() == "fft":
corr = correlate(observed_log_spectrum.flux.value,
(template_log_spectrum.flux.value * normalization),
method="fft")
else:
corr = correlate(observed_log_spectrum.flux.value,
(template_log_spectrum.flux.value * normalization),
method="direct")

# Compute lag
# wave_l is the wavelength array equally spaced in log space.
Expand Down

0 comments on commit fb5bd08

Please sign in to comment.