Skip to content

Commit

Permalink
pep8 updates
Browse files Browse the repository at this point in the history
  • Loading branch information
bsunnquist committed Jul 25, 2024
1 parent 99f9a51 commit 301c0a8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"source": [
"This notebook demonstrates how to remove [wisps](https://jwst-docs.stsci.edu/known-issues-with-jwst-data/nircam-known-issues/nircam-scattered-light-artifacts#NIRCamScatteredLightArtifacts-wispsWisps) from NIRCam imaging data. Wisps are a scattered light feature affecting detectors A3, A4, B3, and B4. For a given filter, wisps appear in the same detector location with only their brightness varying between exposures; therefore, they can be removed from science data by scaling and subtracting a wisp template (i.e. a median combination of all wisp appearances).\n",
"\n",
"Wisp templates used by this notebook can be downloaded from the version3 folder in the [NIRCam wisp template Box folder](https://stsci.box.com/s/1bymvf1lkrqbdn9rnkluzqk30e8o2bne). For this notebook, only the F200W templates are needed.\n",
"Wisp templates used by this notebook are available in the version3 folder in the [NIRCam wisp template Box folder](https://stsci.box.com/s/1bymvf1lkrqbdn9rnkluzqk30e8o2bne). For this notebook, only the F200W templates are needed (i.e. WISP_NRCA3_F200W_CLEAR.fits, WISP_NRCA4_F200W_CLEAR.fits, WISP_NRCB3_F200W_CLEAR.fits and WISP_NRCB4_F200W_CLEAR.fits).\n",
"\n",
"This notebook uses the `subtract_wisp.py` code to scale and subtract the wisps. That code can be used by itself within python, and is preferred if calibrating a large number of files in parallel, but this notebook will be used to demonstrate the various parameters available to optimize wisp removal. For each notebook cell, we'll also show the corresponding command to run the equivalent in python."
]
Expand All @@ -72,13 +72,13 @@
"# pip install jwst==1.14.0\n",
"# pip install astroquery\n",
"\n",
"from subtract_wisp import make_segmap, process_file, process_files, subtract_wisp\n",
"from subtract_wisp import make_segmap, process_file, subtract_wisp\n",
"\n",
"from astropy.io import fits\n",
"from astroquery.mast import Mast, Observations\n",
"import glob\n",
"import matplotlib\n",
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
Expand Down Expand Up @@ -112,12 +112,12 @@
" {\"paramName\": \"observtn\", \"values\": ['6']},\n",
" {\"paramName\": \"exposure\", \"values\": ['00005']},\n",
" {\"paramName\": \"visit\", \"values\": ['004']},\n",
" {\"paramName\": \"detector\", \"values\": ['NRCA3','NRCA4','NRCB3','NRCB4','NRCALONG','NRCBLONG']},\n",
" {\"paramName\": \"detector\", \"values\": ['NRCA3', 'NRCA4', 'NRCB3', 'NRCB4', 'NRCALONG', 'NRCBLONG']},\n",
" {\"paramName\": \"productLevel\", \"values\": ['2b']}]}\n",
"t = Mast().service_request('Mast.Jwst.Filtered.Nircam', params)\n",
"for row in t:\n",
" if '_cal' in row['filename']: # only want cal files\n",
" result = Observations().download_file(row['dataURI'], cache=False)\n"
" result = Observations().download_file(row['dataURI'], cache=False)"
]
},
{
Expand Down Expand Up @@ -155,7 +155,7 @@
"files = glob.glob('*_cal.fits')\n",
"files = [f for f in files if 'long' not in f] # only want shortwave files\n",
"for file in files:\n",
" results = process_file(file, show_plot=True)\n"
" results = process_file(file, show_plot=True)"
]
},
{
Expand Down Expand Up @@ -225,7 +225,7 @@
"outputs": [],
"source": [
"results = process_file('jw01063006004_02101_00005_nrcb4_cal.fits', scale_method='median', poly_degree=0, \n",
" correct_cols=True, show_plot=True)\n"
" correct_cols=True, show_plot=True)"
]
},
{
Expand Down Expand Up @@ -271,7 +271,7 @@
"outputs": [],
"source": [
"results = process_file('jw01063006004_02101_00005_nrcb4_cal.fits', flag_wisp_thresh=0.03, dq_val=1073741824, sub_wisp=False,\n",
" show_plot=True)\n"
" show_plot=True)"
]
},
{
Expand All @@ -292,8 +292,8 @@
"# Check that the data quality array in this file was updated appropriately\n",
"\n",
"dq = fits.getdata('jw01063006004_02101_00005_nrcb4_cal_wisp.fits', 'DQ')\n",
"dq = (dq&1073741824!=0).astype(int) # only want to see pixels flagged as OTHER_BAD_PIXEL, i.e. the dq_val used above\n",
"plt.imshow(dq, cmap='gray', origin='lower', vmin=0, vmax=0.1)\n"
"dq = (dq & 1073741824 != 0).astype(int) # only want to see pixels flagged as OTHER_BAD_PIXEL, i.e. the dq_val used above\n",
"plt.imshow(dq, cmap='gray', origin='lower', vmin=0, vmax=0.1)"
]
},
{
Expand Down Expand Up @@ -325,7 +325,7 @@
"outputs": [],
"source": [
"results = process_file('jw01063006004_02101_00005_nrcb4_cal.fits', seg_from_lw=False, sigma=1.5, save_segmap=True, \n",
" min_wisp=0.01, factor_min=0.5, factor_max=1.5, factor_step=0.05, show_plot=True)\n"
" min_wisp=0.01, factor_min=0.5, factor_max=1.5, factor_step=0.05, show_plot=True)"
]
},
{
Expand Down Expand Up @@ -354,9 +354,9 @@
"data = fits.getdata('jw01063006004_02101_00005_nrcb4_cal.fits')\n",
"segmap = fits.getdata('jw01063006004_02101_00005_nrcb4_cal_seg.fits')\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(20,10))\n",
"fig, axes = plt.subplots(1, 2, figsize=(20, 10))\n",
"axes[0].imshow(data, origin='lower', cmap='gray', vmin=0.18, vmax=0.3)\n",
"axes[1].imshow(segmap, origin='lower', cmap='gray', vmin=0, vmax=0.1)\n"
"axes[1].imshow(segmap, origin='lower', cmap='gray', vmin=0, vmax=0.1)"
]
},
{
Expand Down Expand Up @@ -389,7 +389,7 @@
"wisp_data = fits.getdata('WISP_NRCB4_F200W_CLEAR.fits') * 1.05\n",
"\n",
"# Process the file with the custom wisp template\n",
"results = subtract_wisp('jw01063006004_02101_00005_nrcb4_cal.fits', wisp_data, scale_wisp=False, show_plot=True)\n"
"results = subtract_wisp('jw01063006004_02101_00005_nrcb4_cal.fits', wisp_data, scale_wisp=False, show_plot=True)"
]
},
{
Expand Down Expand Up @@ -426,7 +426,7 @@
"segmap_data = make_segmap(f, sigma=1.0, npixels=8, dilate_segmap=10)\n",
"\n",
"# Scale and subtract the wisp template\n",
"results = subtract_wisp(f, wisp_data=wisp_data, segmap_data=segmap_data, show_plot=True)\n"
"results = subtract_wisp(f, wisp_data=wisp_data, segmap_data=segmap_data, show_plot=True)"
]
},
{
Expand Down
67 changes: 39 additions & 28 deletions notebooks/NIRCam/NIRCam_wisp_subtraction/subtract_wisp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
import multiprocessing
import os
import warnings
warnings.filterwarnings('ignore', message="Input data contains invalid values*") # nan values expected throughout code
warnings.filterwarnings('ignore', message="All-NaN slice encountered*")
warnings.filterwarnings('ignore', message="'obsfix' made the change*") # from astropy wcs during lw segmap blotting
warnings.filterwarnings('ignore', message="'datfix' made the change*") # from astropy wcs during lw segmap blotting

from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
Expand All @@ -55,8 +51,15 @@
from photutils.segmentation import detect_sources, detect_threshold
from scipy.ndimage import binary_dilation, generate_binary_structure

warnings.filterwarnings('ignore', message="Input data contains invalid values*") # nan values expected throughout code
warnings.filterwarnings('ignore', message="All-NaN slice encountered*")
warnings.filterwarnings('ignore', message="'obsfix' made the change*") # from astropy wcs during lw segmap blotting
warnings.filterwarnings('ignore', message="'datfix' made the change*") # from astropy wcs during lw segmap blotting


# -----------------------------------------------------------------------------


def make_segmap(f, seg_from_lw=True, sigma=0.8, npixels=10, dilate_segmap=5, save_segmap=False):
"""
Make a segmentation map for the input file.
Expand Down Expand Up @@ -116,15 +119,15 @@ def make_segmap(f, seg_from_lw=True, sigma=0.8, npixels=10, dilate_segmap=5, sav
# Make the segmentation map
threshold = detect_threshold(data, sigma)
g = Gaussian2DKernel(x_stddev=3)
data_conv = convolve(data, g, mask=dq&1!=0) # Smooth input image before detecting sources
seg = detect_sources(data_conv, threshold, npixels=npixels, mask=dq&1!=0) # avoid bad pixels as sources
data_conv = convolve(data, g, mask=dq & 1 != 0) # Smooth input image before detecting sources
seg = detect_sources(data_conv, threshold, npixels=npixels, mask=dq & 1 != 0) # avoid bad pixels as sources
segmap_data = seg.data
segmap_data[segmap_data!=0] = 1
segmap_data[segmap_data != 0] = 1

# Dilate the segmap outwards
if dilate_segmap != 0:
segmap_data = binary_dilation(segmap_data, iterations=dilate_segmap, structure=generate_binary_structure(2, 2))
segmap_data[segmap_data!=0] = 1
segmap_data[segmap_data != 0] = 1

# Blot LW segmap back onto SW detector space
if (seg_from_lw) & ('long' not in detector):
Expand All @@ -135,9 +138,9 @@ def make_segmap(f, seg_from_lw=True, sigma=0.8, npixels=10, dilate_segmap=5, sav
wcs = WCS(fits.getheader(f_sw, 'SCI')) # sw cal wcs
coords = wcs.world_to_pixel(sky_coords)
for i in np.arange(len(coords[0])):
y,x = int(coords[1][i]), int(coords[0][i])
if (y<2048) & (x<2048) & (y>=0) & (x>=0):
segmap_tmp[y,x] = 1
y, x = int(coords[1][i]), int(coords[0][i])
if (y < 2048) & (x < 2048) & (y >= 0) & (x >= 0):
segmap_tmp[y, x] = 1
# Dilate to compensate for y,x rounding due to different lw/sw pixel scales
segmap_data = binary_dilation(segmap_tmp, iterations=1, structure=generate_binary_structure(2, 2)).astype(int)

Expand All @@ -148,8 +151,10 @@ def make_segmap(f, seg_from_lw=True, sigma=0.8, npixels=10, dilate_segmap=5, sav

return segmap_data


# -----------------------------------------------------------------------------


def process_file(f, wisp_dir='./', create_segmap=True, seg_from_lw=True, sigma=0.8, npixels=10, dilate_segmap=5,
save_segmap=False, sub_wisp=True, gauss_smooth_wisp=False, gauss_stddev=3.0, scale_wisp=True,
scale_method='mad', poly_degree=5, factor_min=0.0, factor_max=2.0, factor_step=0.01, min_wisp=None,
Expand Down Expand Up @@ -186,17 +191,18 @@ def process_file(f, wisp_dir='./', create_segmap=True, seg_from_lw=True, sigma=0
segmap_data = np.zeros(wisp_data.shape).astype(int)

# Scale and subtract wisp template
results = subtract_wisp(f, wisp_data=wisp_data, segmap_data=segmap_data, sub_wisp=sub_wisp,
gauss_smooth_wisp=gauss_smooth_wisp, gauss_stddev=gauss_stddev,
scale_wisp=scale_wisp, scale_method=scale_method, poly_degree=poly_degree,
factor_min=factor_min, factor_max=factor_max, factor_step=factor_step,
min_wisp=min_wisp, flag_wisp_thresh=flag_wisp_thresh, dq_val=dq_val,
correct_rows=correct_rows, correct_cols=correct_cols, save_data=save_data,
save_model=save_model, plot=plot, show_plot=show_plot, suffix=suffix)
_ = subtract_wisp(f, wisp_data=wisp_data, segmap_data=segmap_data, sub_wisp=sub_wisp,
gauss_smooth_wisp=gauss_smooth_wisp, gauss_stddev=gauss_stddev,
scale_wisp=scale_wisp, scale_method=scale_method, poly_degree=poly_degree,
factor_min=factor_min, factor_max=factor_max, factor_step=factor_step,
min_wisp=min_wisp, flag_wisp_thresh=flag_wisp_thresh, dq_val=dq_val,
correct_rows=correct_rows, correct_cols=correct_cols, save_data=save_data,
save_model=save_model, plot=plot, show_plot=show_plot, suffix=suffix)
print('Processing complete for {}'.format(f))

# -----------------------------------------------------------------------------


def process_files(files, nproc=6, **kwargs):
""""Wrapper around the process_file() function to allow for multiprocessing."""

Expand All @@ -207,12 +213,14 @@ def process_files(files, nproc=6, **kwargs):
# Proess the files
process_file_partial = partial(process_file, **kwargs)
p = multiprocessing.Pool(nproc)
results = p.map(process_file_partial, files)
_ = p.map(process_file_partial, files)
p.close()
p.join()


# -----------------------------------------------------------------------------


def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wisp=False, gauss_stddev=3.0, scale_wisp=True,
scale_method='mad', poly_degree=5, factor_min=0.0, factor_max=2.0, factor_step=0.01, min_wisp=None,
flag_wisp_thresh=None, dq_val=1, correct_rows=True, correct_cols=False, save_data=True, save_model=True, plot=True,
Expand Down Expand Up @@ -369,13 +377,13 @@ def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wi
# pixels in the wisp region are unmasked.
data_masked = np.copy(data)
wisp_data_masked = np.copy(wisp_data)
data_masked[(dq&1!=0) | (segmap_data!=0) | (wisp_mask==0)] = np.nan
wisp_data_masked[(dq&1!=0) | (segmap_data!=0) |(wisp_mask==0)] = np.nan
data_masked[(dq & 1 != 0) | (segmap_data != 0) | (wisp_mask == 0)] = np.nan
wisp_data_masked[(dq & 1 != 0) | (segmap_data != 0) |(wisp_mask == 0)] = np.nan

# Make a version of the original data where only good pixels outside
# the wisp region are unmasked.
data_masked_ff = np.copy(data)
data_masked_ff[(dq&1!=0) | (segmap_data!=0) | (wisp_mask!=0)] = np.nan
data_masked_ff[(dq & 1 != 0) | (segmap_data != 0) | (wisp_mask != 0)] = np.nan

# Correct median-collapsed row/column offsets, representing the 1/f residuals
# and odd-even column residuals and amp offsets, respectively.
Expand All @@ -388,8 +396,7 @@ def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wi
collapsed_cols = np.nanmedian(data_masked_ff - med, axis=0)
else:
collapsed_cols = np.zeros(2048)
correction_image = np.tile(collapsed_cols, (2048, 1)) + \
np.swapaxes(np.tile(collapsed_rows, (2048, 1)), 0, 1)
correction_image = np.tile(collapsed_cols, (2048, 1)) + np.swapaxes(np.tile(collapsed_rows, (2048, 1)), 0, 1)
data_masked = data_masked - correction_image
data_masked_ff = data_masked_ff - correction_image
med = np.nanmedian(data_masked_ff)
Expand Down Expand Up @@ -422,12 +429,12 @@ def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wi

# Only subtract wisp values above the specified threshold
if min_wisp is not None:
wisp_model[wisp_model<min_wisp] = 0
wisp_model[wisp_model < min_wisp] = 0

# Flag wisp values above the specified thereshold in DQ array
if flag_wisp_thresh is not None:
new_dq = np.copy(dq)
new_dq[(dq&dq_val==0) & (wisp_model>flag_wisp_thresh)] += dq_val
new_dq[(dq & dq_val == 0) & (wisp_model > flag_wisp_thresh)] += dq_val
else:
new_dq = dq

Expand All @@ -448,7 +455,7 @@ def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wi

# Make diagnostic plots
if plot:
fig, axes = plt.subplots(1, 4, figsize=(60,10))
fig, axes = plt.subplots(1, 4, figsize=(60, 10))
for ax in axes:
ax.tick_params(axis='both', which='major', labelsize=20)
# Plot original image
Expand Down Expand Up @@ -485,9 +492,11 @@ def subtract_wisp(f, wisp_data, segmap_data=None, sub_wisp=True, gauss_smooth_wi

return new_data, wisp_model, factors, residuals, factor


# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


def parse_args():
"""
Parses command line arguments.
Expand Down Expand Up @@ -546,11 +555,13 @@ def parse_args():

return args


# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


if __name__ == '__main__':

# Get the command line arguments
args = parse_args()

Expand Down

0 comments on commit 301c0a8

Please sign in to comment.