Skip to content

Commit

Permalink
Merge pull request #2334 from desihub/large_trace_shifts
Browse files Browse the repository at this point in the history
script to fit large trace shifts
  • Loading branch information
sbailey committed Aug 22, 2024
2 parents b0f95ae + d94aa17 commit 592e2e8
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 2 deletions.
17 changes: 17 additions & 0 deletions bin/desi_compute_large_trace_shifts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python
#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-

"""
This script computes the trace shifts for a preprocessed image, using a PSF and a set of known lines
"""

import sys
import desispec.scripts.large_trace_shifts as large_trace_shifts


if __name__ == '__main__':
args = large_trace_shifts.parse()
sys.exit(large_trace_shifts.main(args))
10 changes: 8 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ desispec API
.. automodule:: desispec.joincosmics
:members:

.. automodule:: desispec.large_trace_shifts
:members:

.. automodule:: desispec.linalg
:members:

Expand Down Expand Up @@ -460,7 +463,7 @@ desispec API

.. automodule:: desispec.scripts.createoverride
:members:

.. automodule:: desispec.scripts.daily_processing
:members:

Expand Down Expand Up @@ -506,6 +509,9 @@ desispec API
.. automodule:: desispec.scripts.interpolate_fiber_psf
:members:

.. automodule:: desispec.scripts.large_trace_shifts
:members:

.. automodule:: desispec.scripts.link_calibnight
:members:

Expand Down Expand Up @@ -625,7 +631,7 @@ desispec API

.. automodule:: desispec.scripts.update_exptable
:members:

.. automodule:: desispec.scripts.update_spectra
:members:

Expand Down
128 changes: 128 additions & 0 deletions py/desispec/large_trace_shifts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
desispec.large_trace_shifts
===========================
"""

from __future__ import absolute_import, division

import os
import sys
import numpy as np
from scipy.spatial import cKDTree as KDTree
from scipy.signal import fftconvolve

from desiutil.log import get_logger


def detect_spots_in_image(image) :
'''
Detection of spots in preprocessed arc lamp image
Args:
image : preprocessed arc lamp image (desispec.Image object)
returns:
xc: 1D float numpy array with xccd spot coordinates in the image (CCD column number)
yc: 1D float numpy array with yccd spot coordinates in the image (CCD row number)
'''

log = get_logger()

# set to zero masked pixels
image.ivar *= (image.mask==0)
image.ivar *= (image.ivar>0)
image.pix *= (image.ivar>0)

# convolve with Gaussian kernel
hw = 3
sigma = 1.
x = np.tile(np.arange(-hw,hw+1),(2*hw+1,1))
y = x.T.copy()
kernel = np.exp(-(x**2+y**2)/2/sigma**2)
kernel /= np.sum(kernel)
simg = fftconvolve(image.pix,kernel,mode='same')
sivar = fftconvolve(image.ivar,kernel**2,mode='same')
sivar *= (sivar>0)

log.info("detections")
nsig = 6
detections = (simg*np.sqrt(sivar))>nsig
peaks=np.zeros(simg.shape)
peaks[1:-1,1:-1] = (detections[1:-1,1:-1]>0)\
*(simg[1:-1,1:-1]>simg[2:,1:-1])\
*(simg[1:-1,1:-1]>simg[:-2,1:-1])\
*(simg[1:-1,1:-1]>simg[1:-1,2:])\
*(simg[1:-1,1:-1]>simg[1:-1,:-2])

log.info("peak coordinates")
x=np.tile(np.arange(simg.shape[1]),(simg.shape[0],1))
y=np.tile(np.arange(simg.shape[0]),(simg.shape[1],1)).T
xp=x[peaks>0]
yp=y[peaks>0]

nspots=xp.size
if nspots>1e5 :
message="way too many spots detected: {}. Aborting".format(nspots)
log.error(message)
raise RuntimeError(message)

log.info("refit {} spots centers".format(nspots))
xc=np.zeros(nspots)
yc=np.zeros(nspots)
for p in range(nspots) :
b0=yp[p]-3
e0=yp[p]+4
b1=xp[p]-3
e1=xp[p]+4
spix=np.sum(image.pix[b0:e0,b1:e1])
xc[p]=np.sum(image.pix[b0:e0,b1:e1]*x[b0:e0,b1:e1])/spix
yc[p]=np.sum(image.pix[b0:e0,b1:e1]*y[b0:e0,b1:e1])/spix
log.info("done")

return xc,yc


# copied from desimeter to avoid dependencies


def match_same_system(x1,y1,x2,y2,remove_duplicates=True) :
'''
match two catalogs, assuming the coordinates are in the same coordinate system (no transfo)
Args:
x1 : float numpy array of coordinates along first axis of cartesian coordinate system
y1 : float numpy array of coordinates along second axis in same system
x2 : float numpy array of coordinates along first axis in same system
y2 : float numpy array of coordinates along second axis in same system
returns:
indices_2 : integer numpy array. if ii is a index array for entries in the first catalog,
indices_2[ii] is the index array of best matching entries in the second catalog.
(one should compare x1[ii] with x2[indices_2[ii]])
negative indices_2 indicate unmatched entries
distances : distances between pairs. It can be used to discard bad matches
'''

xy1=np.array([x1,y1]).T
xy2=np.array([x2,y2]).T
tree2 = KDTree(xy2)
distances,indices_2 = tree2.query(xy1,k=1)

if remove_duplicates :
unique_indices_2 = np.unique(indices_2)
n_duplicates = np.sum(indices_2>=0)-np.sum(unique_indices_2>=0)
if n_duplicates > 0 :
for i2 in unique_indices_2 :
jj=np.where(indices_2==i2)[0]
if jj.size>1 :
kk=np.argsort(distances[jj])
indices_2[jj[kk[1:]]] = -1

distances[indices_2<0] = np.inf
return indices_2,distances
133 changes: 133 additions & 0 deletions py/desispec/scripts/large_trace_shifts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
desispec.scripts.large_trace_shifts
===================================
"""

import os, sys
import argparse
import numpy as np
import matplotlib.pyplot as plt

from desispec.io.xytraceset import read_xytraceset
from desispec.io import read_image
from desiutil.log import get_logger
from desispec.large_trace_shifts import detect_spots_in_image,match_same_system
from desispec.trace_shifts import write_traces_in_psf

def parse(options=None):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Find large trace shifts by matching arc lamp spots in preprocessed images.""")
parser.add_argument('--ref-image', type = str, default = None, required=True,
help = 'path of DESI reference preprocessed arc lamps fits image')
parser.add_argument('-i','--image', type = str, default = None, required=True,
help = 'path of DESI preprocessed arc lamps fits image')
parser.add_argument('--ref-psf', type = str, default = None, required=False,
help = 'path of DESI psf fits file corresponding to the reference image')
parser.add_argument('-o','--output-psf', type = str, default = None, required=False,
help = 'path of output shifted psf file')
parser.add_argument('--plot',action='store_true', help="plot spots")

args = parser.parse_args(options)
return args

def main(args=None) :

log= get_logger()

if not isinstance(args, argparse.Namespace):
args = parse(args)


ref_image = read_image(args.ref_image)
xref,yref = detect_spots_in_image(ref_image)

in_image = read_image(args.image)
xin,yin = detect_spots_in_image(in_image)

indices,distances = match_same_system(xref,yref,xin,yin,remove_duplicates=True)

ok=(indices>=0)

rmsdist = 1.4*np.median(distances[ok])
ok &= (distances<5*rmsdist)

nmatch = np.sum(ok)
if nmatch<10 :
message = "too few matches: {}. Aborting.".format(nmatch)
log.error(message)
sys.exit(12)
xref=xref[ok]
yref=yref[ok]
xin=xin[indices[ok]]
yin=yin[indices[ok]]


delta_x = np.median(xin-xref)
delta_y = np.median(yin-yref)
log.info("First delta_x = {:.2f} delta_y = {:.2f}".format(delta_x,delta_y))

distances = (xin-xref-delta_x)**2+(yin-yref-delta_y)**2
rmsdist = 1.4*np.median(distances)
ok = (distances<5*rmsdist)
nmatch = np.sum(ok)
if nmatch<10 :
message = "too few matches: {}. Aborting.".format(nmatch)
log.error(message)
sys.exit(12)

xref=xref[ok]
yref=yref[ok]
xin=xin[ok]
yin=yin[ok]
delta_x = np.median(xin-xref)
delta_y = np.median(yin-yref)
distances = (xin-xref-delta_x)**2+(yin-yref-delta_y)**2
rms_dist = np.sqrt(np.mean(distances**2))
log.info("Refined delta_x = {:.2f} delta_y = {:.2f} rms dist = {:.2f}".format(delta_x,delta_y,rms_dist))


if args.ref_psf is not None :
log.info("Read traceset in {}".format(args.ref_psf))
tset = read_xytraceset(args.ref_psf)
tset.x_vs_wave_traceset._coeff[:,0] += delta_x
tset.y_vs_wave_traceset._coeff[:,0] += delta_y

if args.output_psf is not None :
log.info("Write modified traceset in {}".format(args.output_psf))
write_traces_in_psf(args.ref_psf,args.output_psf,tset)

if args.plot :
if 0 :
plt.figure()
plt.plot(xref,yref,".")
plt.plot(xin,yin,".")

plt.figure()
plt.subplot(221)
plt.plot(xref,xin-xref,".")
plt.axhline(delta_x,linestyle="--")
plt.xlabel("X")
plt.ylabel("dX")
plt.subplot(222)
plt.plot(yref,xin-xref,".")
plt.axhline(delta_x,linestyle="--")
plt.xlabel("Y")
plt.ylabel("dX")
plt.subplot(223)
plt.plot(xref,yin-yref,".")
plt.axhline(delta_y,linestyle="--")
plt.xlabel("X")
plt.ylabel("dY")
plt.subplot(224)
plt.plot(yref,yin-yref,".")
plt.axhline(delta_y,linestyle="--")
plt.xlabel("Y")
plt.ylabel("dY")

plt.figure()
plt.plot(xref,yref,"X",color="C0")
plt.plot(xin,yin,".",color="red",alpha=0.7)
plt.plot(xin-delta_x,yin-delta_y,".",color="C1")

plt.show()

0 comments on commit 592e2e8

Please sign in to comment.