From 3d73827c3a4628606869a206e3d74bac8b12a1ec Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Fri, 24 Nov 2023 19:02:25 -0500 Subject: [PATCH] Use dot access instead of dict access for dqflags --- romancal/dq_init/dq_init_step.py | 7 ++-- romancal/dq_init/tests/test_dq_init.py | 41 ++++++++----------- romancal/flatfield/flat_field.py | 10 ++--- romancal/jump/jump_step.py | 12 +++--- romancal/lib/basic_utils.py | 9 ++-- romancal/lib/psf.py | 4 +- romancal/linearity/linearity_step.py | 4 +- .../outlier_detection/outlier_detection.py | 13 +++--- romancal/pipeline/exposure_pipeline.py | 6 +-- romancal/ramp_fitting/ramp_fit_step.py | 23 +++++------ .../ramp_fitting/tests/test_ramp_fit_cas22.py | 12 +----- .../ramp_fitting/tests/test_ramp_fit_ols.py | 17 ++------ romancal/regtest/test_wfi_pipeline.py | 8 ++-- romancal/resample/resample_utils.py | 4 +- romancal/saturation/saturation.py | 4 +- romancal/saturation/tests/test_saturation.py | 41 ++++++++----------- romancal/skymatch/tests/test_skymatch.py | 15 +++---- .../source_detection/source_detection_step.py | 7 ++-- 18 files changed, 95 insertions(+), 142 deletions(-) diff --git a/romancal/dq_init/dq_init_step.py b/romancal/dq_init/dq_init_step.py index 26d961e4b..9d9da1767 100644 --- a/romancal/dq_init/dq_init_step.py +++ b/romancal/dq_init/dq_init_step.py @@ -2,8 +2,9 @@ import numpy as np import roman_datamodels as rdm -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils from roman_datamodels.datamodels import RampModel +from roman_datamodels.dqflags import pixel from romancal.dq_init import dq_initialization from romancal.stpipe import RomanStep @@ -74,9 +75,7 @@ def process(self, input): x_start = input_model.meta.guidestar.gw_window_xstart x_end = input_model.meta.guidestar.gw_window_xsize + x_start # set pixeldq array to GW_AFFECTED_DATA (2**4) for the given range - output_model.pixeldq[int(x_start) : int(x_end), :] = dqflags.pixel[ - "GW_AFFECTED_DATA" - ] + output_model.pixeldq[int(x_start) : int(x_end), :] = pixel.GW_AFFECTED_DATA self.log.info( f"Flagging rows from: {x_start} to {x_end} as affected by guide window read" ) diff --git a/romancal/dq_init/tests/test_dq_init.py b/romancal/dq_init/tests/test_dq_init.py index 9c403c63e..cd6822d19 100644 --- a/romancal/dq_init/tests/test_dq_init.py +++ b/romancal/dq_init/tests/test_dq_init.py @@ -1,8 +1,9 @@ import numpy as np import pytest from astropy import units as u -from roman_datamodels import dqflags, maker_utils, stnode +from roman_datamodels import maker_utils, stnode from roman_datamodels.datamodels import MaskRefModel, ScienceRawModel +from roman_datamodels.dqflags import pixel from romancal.dq_init import DQInitStep from romancal.dq_init.dq_initialization import do_dqinit @@ -55,18 +56,16 @@ def test_dq_im(xstart, ystart, xsize, ysize, ngroups, instrument, exp_type): # assert that the pixels read back in match the mapping from ref data to # science data - assert dqdata[100, 100] == dqflags.pixel["SATURATED"] - assert dqdata[200, 100] == dqflags.pixel["JUMP_DET"] - assert dqdata[300, 100] == dqflags.pixel["DROPOUT"] - assert dqdata[400, 100] == dqflags.pixel["PERSISTENCE"] - assert dqdata[500, 100] == dqflags.pixel["DO_NOT_USE"] - assert dqdata[600, 100] == dqflags.pixel["GW_AFFECTED_DATA"] - assert dqdata[100, 200] == dqflags.pixel["SATURATED"] + dqflags.pixel["DO_NOT_USE"] - assert dqdata[200, 200] == dqflags.pixel["JUMP_DET"] + dqflags.pixel["DO_NOT_USE"] - assert dqdata[300, 200] == dqflags.pixel["DROPOUT"] + dqflags.pixel["DO_NOT_USE"] - assert ( - dqdata[400, 200] == dqflags.pixel["PERSISTENCE"] + dqflags.pixel["DO_NOT_USE"] - ) + assert dqdata[100, 100] == pixel.SATURATED + assert dqdata[200, 100] == pixel.JUMP_DET + assert dqdata[300, 100] == pixel.DROPOUT + assert dqdata[400, 100] == pixel.PERSISTENCE + assert dqdata[500, 100] == pixel.DO_NOT_USE + assert dqdata[600, 100] == pixel.GW_AFFECTED_DATA + assert dqdata[100, 200] == pixel.SATURATED + pixel.DO_NOT_USE + assert dqdata[200, 200] == pixel.JUMP_DET + pixel.DO_NOT_USE + assert dqdata[300, 200] == pixel.DROPOUT + pixel.DO_NOT_USE + assert dqdata[400, 200] == pixel.PERSISTENCE + pixel.DO_NOT_USE def test_groupdq(): @@ -171,15 +170,9 @@ def test_dq_add1_groupdq(): # test if pixels in pixeldq were incremented in value by 1 # check that previous dq flag is added to mask value - assert ( - outfile.pixeldq[505, 505] - == dqflags.pixel["JUMP_DET"] + dqflags.pixel["DO_NOT_USE"] - ) + assert outfile.pixeldq[505, 505] == pixel.JUMP_DET + pixel.DO_NOT_USE # check two flags propagate correctly - assert ( - outfile.pixeldq[400, 500] - == dqflags.pixel["SATURATED"] + dqflags.pixel["DO_NOT_USE"] - ) + assert outfile.pixeldq[400, 500] == pixel.SATURATED + pixel.DO_NOT_USE @pytest.mark.parametrize( @@ -303,7 +296,7 @@ def test_dqinit_resultantdq(instrument, exptype): wfi_sci_raw.meta["guidestar"]["gw_window_xstart"] = 1012 wfi_sci_raw.meta["guidestar"]["gw_window_xsize"] = 16 wfi_sci_raw.meta.exposure.type = exptype - wfi_sci_raw.resultantdq[1, 12, 12] = dqflags.pixel["DROPOUT"] + wfi_sci_raw.resultantdq[1, 12, 12] = pixel["DROPOUT"] wfi_sci_raw.data = u.Quantity( np.ones(shape, dtype=np.uint16), u.DN, dtype=np.uint16 ) @@ -330,8 +323,8 @@ def test_dqinit_resultantdq(instrument, exptype): # check to see the resultantdq is the correct shape assert wfi_sci_raw_model.resultantdq.shape == shape # check to see the resultantdq & groupdq have the correct value - assert wfi_sci_raw_model.resultantdq[1, 12, 12] == dqflags.pixel["DROPOUT"] - assert result.groupdq[1, 12, 12] == dqflags.pixel["DROPOUT"] + assert wfi_sci_raw_model.resultantdq[1, 12, 12] == pixel["DROPOUT"] + assert result.groupdq[1, 12, 12] == pixel["DROPOUT"] @pytest.mark.parametrize( diff --git a/romancal/flatfield/flat_field.py b/romancal/flatfield/flat_field.py index ee3c47d34..60ced83bb 100644 --- a/romancal/flatfield/flat_field.py +++ b/romancal/flatfield/flat_field.py @@ -6,7 +6,7 @@ import numpy as np from astropy import units as u -from roman_datamodels import dqflags +from roman_datamodels.dqflags import pixel log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) @@ -90,17 +90,15 @@ def apply_flat_field(science, flat): # Find pixels in the flat that have a value of NaN and set # their DQ to NO_FLAT_FIELD flat_nan = np.isnan(flat_data) - flat_dq[flat_nan] = np.bitwise_or(flat_dq[flat_nan], dqflags.pixel["NO_FLAT_FIELD"]) + flat_dq[flat_nan] = np.bitwise_or(flat_dq[flat_nan], pixel.NO_FLAT_FIELD) # Find pixels in the flat that have a value of zero, and set # their DQ to NO_FLAT_FIELD flat_zero = np.where(flat_data == 0.0) - flat_dq[flat_zero] = np.bitwise_or( - flat_dq[flat_zero], dqflags.pixel["NO_FLAT_FIELD"] - ) + flat_dq[flat_zero] = np.bitwise_or(flat_dq[flat_zero], pixel.NO_FLAT_FIELD) # Find all pixels in the flat that have a DQ value of NO_FLAT_FIELD - flat_bad = np.bitwise_and(flat_dq, dqflags.pixel["NO_FLAT_FIELD"]) + flat_bad = np.bitwise_and(flat_dq, pixel.NO_FLAT_FIELD) # Reset the flat value of all bad pixels to 1.0, so that no # correction is made diff --git a/romancal/jump/jump_step.py b/romancal/jump/jump_step.py index 0407acd5a..31d864064 100644 --- a/romancal/jump/jump_step.py +++ b/romancal/jump/jump_step.py @@ -7,7 +7,7 @@ import numpy as np from roman_datamodels import datamodels as rdd -from roman_datamodels import dqflags +from roman_datamodels.dqflags import group, pixel from stcal.jump.jump import detect_jumps from romancal.stpipe import RomanStep @@ -118,11 +118,11 @@ def process(self, input): # separate PR dqflags_d = {} # Dict of DQ flags dqflags_d = { - "GOOD": dqflags.group["GOOD"], - "DO_NOT_USE": dqflags.group["DO_NOT_USE"], - "SATURATED": dqflags.group["SATURATED"], - "JUMP_DET": dqflags.group["JUMP_DET"], - "NO_GAIN_VALUE": dqflags.pixel["NO_GAIN_VALUE"], + "GOOD": group.GOOD, + "DO_NOT_USE": group.DO_NOT_USE, + "SATURATED": group.SATURATED, + "JUMP_DET": group.JUMP_DET, + "NO_GAIN_VALUE": pixel.NO_GAIN_VALUE, } gdq, pdq, *_ = detect_jumps( diff --git a/romancal/lib/basic_utils.py b/romancal/lib/basic_utils.py index 579aab5d2..ad64debec 100644 --- a/romancal/lib/basic_utils.py +++ b/romancal/lib/basic_utils.py @@ -1,11 +1,8 @@ """General utility objects""" import numpy as np -from roman_datamodels import dqflags from roman_datamodels.datamodels import AssociationsModel - -SATURATEDPIX = dqflags.pixel["SATURATED"] -SATURATEDGRP = dqflags.group["SATURATED"] +from roman_datamodels.dqflags import group, pixel def bytes2human(n): @@ -48,9 +45,9 @@ def is_fully_saturated(model): Check to see if all data pixels are flagged as saturated. """ - if np.all(np.bitwise_and(model.groupdq, SATURATEDGRP) == SATURATEDGRP): + if np.all(np.bitwise_and(model.groupdq, group.SATURATED) == group.SATURATED): return True - elif np.all(np.bitwise_and(model.pixeldq, SATURATEDPIX) == SATURATEDPIX): + elif np.all(np.bitwise_and(model.pixeldq, pixel.SATURATED) == pixel.SATURATED): return True return False diff --git a/romancal/lib/psf.py b/romancal/lib/psf.py index b91e2da8d..d4ba30630 100644 --- a/romancal/lib/psf.py +++ b/romancal/lib/psf.py @@ -20,7 +20,7 @@ SourceGrouper, ) from roman_datamodels.datamodels import ImageModel -from roman_datamodels.dqflags import pixel as roman_dq_flag_enum +from roman_datamodels.dqflags import pixel from webbpsf import conf, gridded_library, restart_logging __all__ = [ @@ -388,7 +388,7 @@ def dq_to_boolean_mask(image_model_or_dq, ignore_flags=0, flag_map_name="ROMAN_D dq = image_model_or_dq # add the Roman DQ flags to the astropy bitmask registry: - dq_flag_map = {dq.name: dq.value for dq in roman_dq_flag_enum if dq.name != "GOOD"} + dq_flag_map = {dq.name: dq.value for dq in pixel if dq.name != "GOOD"} bitmask.extend_bit_flag_map(flag_map_name, **dq_flag_map) # convert the bitmask to a boolean mask: diff --git a/romancal/linearity/linearity_step.py b/romancal/linearity/linearity_step.py index ade510d9c..6bd083281 100644 --- a/romancal/linearity/linearity_step.py +++ b/romancal/linearity/linearity_step.py @@ -5,7 +5,7 @@ import numpy as np from astropy import units as u from roman_datamodels import datamodels as rdd -from roman_datamodels import dqflags +from roman_datamodels.dqflags import pixel from stcal.linearity.linearity import linearity_correction from romancal.stpipe import RomanStep @@ -53,7 +53,7 @@ def process(self, input): # The third return value is the procesed zero frame which # Roman does not use. new_data, new_pdq, _ = linearity_correction( - input_model.data.value, gdq, pdq, lin_coeffs, lin_dq, dqflags.pixel + input_model.data.value, gdq, pdq, lin_coeffs, lin_dq, pixel ) input_model.data = u.Quantity( diff --git a/romancal/outlier_detection/outlier_detection.py b/romancal/outlier_detection/outlier_detection.py index 7ba76d7de..e84bca95f 100644 --- a/romancal/outlier_detection/outlier_detection.py +++ b/romancal/outlier_detection/outlier_detection.py @@ -9,7 +9,7 @@ from astropy.units import Quantity from drizzle.cdrizzle import tblot from roman_datamodels import datamodels as rdm -from roman_datamodels import dqflags +from roman_datamodels.dqflags import pixel from scipy import ndimage from romancal.datamodels import ModelContainer @@ -21,9 +21,6 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -DO_NOT_USE = dqflags.pixel["DO_NOT_USE"] -OUTLIER = dqflags.pixel["OUTLIER"] - __all__ = ["OutlierDetection", "flag_cr", "abs_deriv"] @@ -375,13 +372,15 @@ def flag_cr( cr_mask = np.greater(diff_noise.value, snr1 * err_data.value) # Count existing DO_NOT_USE pixels - count_existing = np.count_nonzero(sci_image.dq & DO_NOT_USE) + count_existing = np.count_nonzero(sci_image.dq & pixel.DO_NOT_USE) # Update the DQ array in the input image. - sci_image.dq = np.bitwise_or(sci_image.dq, cr_mask * (DO_NOT_USE | OUTLIER)) + sci_image.dq = np.bitwise_or( + sci_image.dq, cr_mask * (pixel.DO_NOT_USE | pixel.OUTLIER) + ) # Report number (and percent) of new DO_NOT_USE pixels found - count_outlier = np.count_nonzero(sci_image.dq & DO_NOT_USE) + count_outlier = np.count_nonzero(sci_image.dq & pixel.DO_NOT_USE) count_added = count_outlier - count_existing percent_cr = count_added / (sci_image.shape[0] * sci_image.shape[1]) * 100 log.info(f"New pixels flagged as outliers: {count_added} ({percent_cr:.2f}%)") diff --git a/romancal/pipeline/exposure_pipeline.py b/romancal/pipeline/exposure_pipeline.py index 1bf8a789d..fa18240bd 100644 --- a/romancal/pipeline/exposure_pipeline.py +++ b/romancal/pipeline/exposure_pipeline.py @@ -4,7 +4,7 @@ import numpy as np from roman_datamodels import datamodels as rdm -from roman_datamodels import dqflags +from roman_datamodels.dqflags import group import romancal.datamodels.filetype as filetype @@ -236,9 +236,7 @@ def create_fully_saturated_zeroed_image(self, input_model): input_model, ( np.zeros(input_model.data.shape[1:], dtype=input_model.data.dtype), - input_model.pixeldq - | input_model.groupdq[0] - | dqflags.group["SATURATED"], + input_model.pixeldq | input_model.groupdq[0] | group.SATURATED, np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype), np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype), np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype), diff --git a/romancal/ramp_fitting/ramp_fit_step.py b/romancal/ramp_fitting/ramp_fit_step.py index cd3fe7d90..bc8fb214e 100644 --- a/romancal/ramp_fitting/ramp_fit_step.py +++ b/romancal/ramp_fitting/ramp_fit_step.py @@ -5,8 +5,9 @@ import numpy as np from astropy import units as u from roman_datamodels import datamodels as rdd -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils from roman_datamodels import stnode as rds +from roman_datamodels.dqflags import group, pixel from stcal.ramp_fitting import ols_cas22_fit, ramp_fit from stcal.ramp_fitting.ols_cas22 import Parameter, Variance @@ -105,7 +106,7 @@ def ols(self, input_model, readnoise_model, gain_model): self.algorithm, self.weighting, max_cores, - dqflags.pixel, + pixel, ) if image_info is not None: @@ -149,9 +150,7 @@ def ols(self, input_model, readnoise_model, gain_model): # Image info order is: data, dq, var_poisson, var_rnoise, err image_info = ( np.zeros(input_model.data.shape[2:], dtype=input_model.data.dtype), - input_model.pixeldq - | input_model.groupdq[0][0] - | dqflags.group["SATURATED"], + input_model.pixeldq | input_model.groupdq[0][0] | group.SATURATED, np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype), np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype), np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype), @@ -396,17 +395,17 @@ def get_pixeldq_flags(groupdq, pixeldq, slopes, err, gain): """ outpixeldq = pixeldq.copy() # jump flagging - m = np.any(groupdq & dqflags.group["JUMP_DET"], axis=0) - outpixeldq |= (m * dqflags.pixel["JUMP_DET"]).astype(np.uint32) + m = np.any(groupdq & group.JUMP_DET, axis=0) + outpixeldq |= (m * pixel.JUMP_DET).astype(np.uint32) # all saturated flagging - m = np.all(groupdq & dqflags.group["SATURATED"], axis=0) - outpixeldq |= (m * dqflags.pixel["SATURATED"]).astype(np.uint32) + m = np.all(groupdq & group.SATURATED, axis=0) + outpixeldq |= (m * pixel.SATURATED).astype(np.uint32) # all either saturated or do not use or NaN slope flagging - satordnu = dqflags.group["SATURATED"] | dqflags.group["DO_NOT_USE"] + satordnu = group.SATURATED | group.DO_NOT_USE m = np.all(groupdq & satordnu, axis=0) m |= ~np.isfinite(slopes) | (err <= 0) - outpixeldq |= (m * dqflags.pixel["DO_NOT_USE"]).astype(np.uint32) + outpixeldq |= (m * pixel.DO_NOT_USE).astype(np.uint32) m = (gain < 0) | ~np.isfinite(gain) - outpixeldq |= (m * dqflags.pixel["NO_GAIN_VALUE"]).astype(np.uint32) + outpixeldq |= (m * pixel.NO_GAIN_VALUE).astype(np.uint32) return outpixeldq diff --git a/romancal/ramp_fitting/tests/test_ramp_fit_cas22.py b/romancal/ramp_fitting/tests/test_ramp_fit_cas22.py index 73d940eba..c9a69bae1 100644 --- a/romancal/ramp_fitting/tests/test_ramp_fit_cas22.py +++ b/romancal/ramp_fitting/tests/test_ramp_fit_cas22.py @@ -4,7 +4,7 @@ import pytest from astropy import units as u from astropy.time import Time -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils from roman_datamodels.datamodels import GainRefModel, RampModel, ReadnoiseRefModel from romancal.ramp_fitting import RampFitStep @@ -15,16 +15,6 @@ # Used to deconstruct the MultiAccum tables into integration times. ROMAN_READ_TIME = 3.04 -DO_NOT_USE = dqflags.group["DO_NOT_USE"] -JUMP_DET = dqflags.group["JUMP_DET"] -SATURATED = dqflags.group["SATURATED"] - -dqflags = { - "DO_NOT_USE": 1, - "SATURATED": 2, - "JUMP_DET": 4, -} - # Basic resultant # # The read pattern is `[[1], [2], [3], [4]]` diff --git a/romancal/ramp_fitting/tests/test_ramp_fit_ols.py b/romancal/ramp_fitting/tests/test_ramp_fit_ols.py index 858899e35..19f2e3447 100644 --- a/romancal/ramp_fitting/tests/test_ramp_fit_ols.py +++ b/romancal/ramp_fitting/tests/test_ramp_fit_ols.py @@ -2,28 +2,19 @@ import pytest from astropy import units as u from astropy.time import Time -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils from roman_datamodels.datamodels import ( GainRefModel, ImageModel, RampModel, ReadnoiseRefModel, ) +from roman_datamodels.dqflags import group from romancal.ramp_fitting import RampFitStep MAXIMUM_CORES = ["none", "quarter", "half", "all"] -DO_NOT_USE = dqflags.group["DO_NOT_USE"] -JUMP_DET = dqflags.group["JUMP_DET"] -SATURATED = dqflags.group["SATURATED"] - -dqflags = { - "DO_NOT_USE": 1, - "SATURATED": 2, - "JUMP_DET": 4, -} - def test_ols_multicore_ramp_fit_match(make_data): """Test various core amount calculation""" @@ -121,7 +112,7 @@ def test_ols_saturated_ramp_fit(max_cores, make_data): model, override_gain, override_readnoise = make_data # Set saturated flag - model.groupdq = model.groupdq | SATURATED + model.groupdq = model.groupdq | group.SATURATED # Run ramp fit step out_model = RampFitStep.call( @@ -139,7 +130,7 @@ def test_ols_saturated_ramp_fit(max_cores, make_data): np.testing.assert_array_equal(out_model.var_rnoise.value, 0) # Test that all pixels are flagged saturated - assert np.all(np.bitwise_and(out_model.dq, SATURATED) == SATURATED) + assert np.all(np.bitwise_and(out_model.dq, group.SATURATED) == group.SATURATED) # Test that original ramp parameters preserved np.testing.assert_allclose(out_model.amp33, model.amp33, 1e-6) diff --git a/romancal/regtest/test_wfi_pipeline.py b/romancal/regtest/test_wfi_pipeline.py index 29af5fc86..f8c6eef14 100644 --- a/romancal/regtest/test_wfi_pipeline.py +++ b/romancal/regtest/test_wfi_pipeline.py @@ -104,9 +104,9 @@ def test_level2_image_processing_pipeline(rtdata, ignore_asdf_paths): pipeline.log.info( "DMS361: Testing that jump detection detected jumps in uneven ramp in " "Level 2 image output......." - + passfail(uneven & np.any(model.dq & pixel["JUMP_DET"])) + + passfail(uneven & np.any(model.dq & pixel.JUMP_DET)) ) - assert uneven & np.any(model.dq & pixel["JUMP_DET"]) + assert uneven & np.any(model.dq & pixel.JUMP_DET) pipeline.log.info( "Status of the step: linearity " + str(model.meta.cal_step.linearity) @@ -347,9 +347,9 @@ def test_level2_grism_processing_pipeline(rtdata, ignore_asdf_paths): pipeline.log.info( "DMS365: Testing that jump detection detected jumps in uneven ramp in " "Level 2 image output......." - + passfail(uneven & np.any(model.dq & pixel["JUMP_DET"])) + + passfail(uneven & np.any(model.dq & pixel.JUMP_DET)) ) - assert uneven & np.any(model.dq & pixel["JUMP_DET"]) + assert uneven & np.any(model.dq & pixel.JUMP_DET) pipeline.log.info( "Status of the step: linearity " + str(model.meta.cal_step.assign_wcs) diff --git a/romancal/resample/resample_utils.py b/romancal/resample/resample_utils.py index e499d2840..75fe46627 100644 --- a/romancal/resample/resample_utils.py +++ b/romancal/resample/resample_utils.py @@ -189,7 +189,9 @@ def build_mask(dqarr, bitvalue): obtain the bit mask. - The resulting bit mask is returned as an ndarray of dtype `numpy.uint8`. """ - bitvalue = interpret_bit_flags(bitvalue, flag_name_map=pixel) + bitvalue = interpret_bit_flags( + bitvalue, flag_name_map={dq.name: dq.value for dq in pixel} + ) if bitvalue is None: return np.ones(dqarr.shape, dtype=np.uint8) diff --git a/romancal/saturation/saturation.py b/romancal/saturation/saturation.py index d8b0cb37d..87314e578 100644 --- a/romancal/saturation/saturation.py +++ b/romancal/saturation/saturation.py @@ -3,7 +3,7 @@ import logging import numpy as np -from roman_datamodels import dqflags +from roman_datamodels.dqflags import pixel from stcal.saturation.saturation import flag_saturated_pixels log = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def flag_saturation(input_model, ref_model): sat_thresh, sat_dq, ATOD_LIMIT, - dqflags.pixel, + pixel, n_pix_grow_sat=0, read_pattern=input_model.meta.exposure.read_pattern, ) diff --git a/romancal/saturation/tests/test_saturation.py b/romancal/saturation/tests/test_saturation.py index 265aa658b..76c0eee3d 100644 --- a/romancal/saturation/tests/test_saturation.py +++ b/romancal/saturation/tests/test_saturation.py @@ -7,10 +7,10 @@ import numpy as np import pytest from astropy import units as u -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils from roman_datamodels.datamodels import ScienceRawModel +from roman_datamodels.dqflags import group, pixel -from romancal.lib import dqflags from romancal.saturation import SaturationStep from romancal.saturation.saturation import flag_saturation @@ -41,7 +41,7 @@ def test_basic_saturation_flagging(setup_wfi_datamodels): # Make sure that groups with signal > saturation limit get flagged satindex = np.argmax(output.data.value[:, 5, 5] == satvalue) - assert np.all(output.groupdq[satindex:, 5, 5] == dqflags.group["SATURATED"]) + assert np.all(output.groupdq[satindex:, 5, 5] == group.SATURATED) def test_read_pattern_saturation_flagging(setup_wfi_datamodels): @@ -84,7 +84,7 @@ def test_read_pattern_saturation_flagging(setup_wfi_datamodels): output = flag_saturation(ramp, satmap) # Make sure that groups after the third get flagged - assert np.all(output.groupdq[2:, 5, 5] == dqflags.group["SATURATED"]) + assert np.all(output.groupdq[2:, 5, 5] == group.SATURATED) def test_ad_floor_flagging(setup_wfi_datamodels): @@ -116,10 +116,7 @@ def test_ad_floor_flagging(setup_wfi_datamodels): output = flag_saturation(ramp, satmap) # Check if the right frames are flagged as saturated - assert np.all( - output.groupdq[satindxs, 5, 5] - == dqflags.group["DO_NOT_USE"] | dqflags.group["AD_FLOOR"] - ) + assert np.all(output.groupdq[satindxs, 5, 5] == group.DO_NOT_USE | group.AD_FLOOR) def test_ad_floor_and_saturation_flagging(setup_wfi_datamodels): @@ -155,12 +152,9 @@ def test_ad_floor_and_saturation_flagging(setup_wfi_datamodels): output = flag_saturation(ramp, satmap) # Check if the right frames are flagged as ad_floor - assert np.all( - output.groupdq[floorindxs, 5, 5] - == dqflags.group["DO_NOT_USE"] | dqflags.group["AD_FLOOR"] - ) + assert np.all(output.groupdq[floorindxs, 5, 5] == group.DO_NOT_USE | group.AD_FLOOR) # Check if the right frames are flagged as saturated - assert np.all(output.groupdq[satindxs, 5, 5] == dqflags.group["SATURATED"]) + assert np.all(output.groupdq[satindxs, 5, 5] == group.SATURATED) def test_signal_fluctuation_flagging(setup_wfi_datamodels): @@ -191,7 +185,7 @@ def test_signal_fluctuation_flagging(setup_wfi_datamodels): # Make sure that all groups after first saturated group are flagged satindex = np.argmax(output.data.value[:, 5, 5] == satvalue) - assert np.all(output.groupdq[satindex:, 5, 5] == dqflags.group["SATURATED"]) + assert np.all(output.groupdq[satindex:, 5, 5] == group.SATURATED) def test_all_groups_saturated(setup_wfi_datamodels): @@ -219,7 +213,7 @@ def test_all_groups_saturated(setup_wfi_datamodels): output = flag_saturation(ramp, satmap) # Make sure all groups are flagged - assert np.all(output.groupdq[:, 5, 5] == dqflags.group["SATURATED"]) + assert np.all(output.groupdq[:, 5, 5] == group.SATURATED) def test_dq_propagation(setup_wfi_datamodels): @@ -266,25 +260,22 @@ def test_no_sat_check(setup_wfi_datamodels): # Set saturation value in the saturation model & DQ value for NO_SAT_CHECK satmap.data[5, 5] = satvalue * satmap.data.unit - satmap.dq[5, 5] = dqflags.pixel["NO_SAT_CHECK"] + satmap.dq[5, 5] = pixel.NO_SAT_CHECK # Also set an existing DQ flag in input science data - ramp.pixeldq[5, 5] = dqflags.pixel["DO_NOT_USE"] + ramp.pixeldq[5, 5] = pixel.DO_NOT_USE # Run the pipeline output = flag_saturation(ramp, satmap) # Make sure output GROUPDQ does not get flagged as saturated # Make sure PIXELDQ is set to NO_SAT_CHECK and original flag - assert np.all(output.groupdq[:, 5, 5] != dqflags.group["SATURATED"]) + assert np.all(output.groupdq[:, 5, 5] != group.SATURATED) # Test that saturation bit is NOT set assert np.all( - output.groupdq[:, 5, 5] & (1 << dqflags.group["SATURATED"].bit_length() - 1) - == 0 - ) - assert output.pixeldq[5, 5] == ( - dqflags.pixel["NO_SAT_CHECK"] + dqflags.pixel["DO_NOT_USE"] + output.groupdq[:, 5, 5] & (1 << group.SATURATED.bit_length() - 1) == 0 ) + assert output.pixeldq[5, 5] == (pixel.NO_SAT_CHECK + pixel.DO_NOT_USE) def test_nans_in_mask(setup_wfi_datamodels): @@ -313,9 +304,9 @@ def test_nans_in_mask(setup_wfi_datamodels): output = flag_saturation(ramp, satmap) # Check that output GROUPDQ is not flagged as saturated - assert np.all(output.groupdq[:, 5, 5] != dqflags.group["SATURATED"]) + assert np.all(output.groupdq[:, 5, 5] != group.SATURATED) # Check that output PIXELDQ is set to NO_SAT_CHECK - assert output.pixeldq[5, 5] == dqflags.pixel["NO_SAT_CHECK"] + assert output.pixeldq[5, 5] == pixel.NO_SAT_CHECK def test_saturation_getbestref(setup_wfi_datamodels): diff --git a/romancal/skymatch/tests/test_skymatch.py b/romancal/skymatch/tests/test_skymatch.py index 5f7134baa..2a65d391b 100644 --- a/romancal/skymatch/tests/test_skymatch.py +++ b/romancal/skymatch/tests/test_skymatch.py @@ -7,16 +7,13 @@ from astropy.modeling import models from gwcs import coordinate_frames as cf from gwcs import wcs as gwcs_wcs -from roman_datamodels import dqflags from roman_datamodels.datamodels import ImageModel +from roman_datamodels.dqflags import pixel from roman_datamodels.maker_utils import mk_level2_image from romancal.datamodels.container import ModelContainer from romancal.skymatch import SkyMatchStep -DO_NOT_USE = dqflags.pixel["DO_NOT_USE"] -SATURATED = dqflags.pixel["SATURATED"] - def mk_gwcs(shape, sky_offset=[0, 0] * u.arcsec, rotate=0 * u.deg): # Example adapted from photutils: @@ -132,10 +129,10 @@ def _add_bad_pixels(im, sat_val, dont_use_val): im.data[-5:, -5:] = sat_val * im_unit im.data[:5, -5:] = sat_val * im_unit - im.dq[:5, :5] = SATURATED - im.dq[-5:, :5] = SATURATED - im.dq[-5:, -5:] = SATURATED - im.dq[:5, -5:] = SATURATED + im.dq[:5, :5] = pixel.SATURATED + im.dq[-5:, :5] = pixel.SATURATED + im.dq[-5:, -5:] = pixel.SATURATED + im.dq[:5, -5:] = pixel.SATURATED mask[:5, :5] = False mask[-5:, :5] = False @@ -148,7 +145,7 @@ def _add_bad_pixels(im, sat_val, dont_use_val): # center im.data[cx : cx + 10, cy : cy + 10] = dont_use_val * im_unit - im.dq[cx : cx + 10, cy : cy + 10] = DO_NOT_USE + im.dq[cx : cx + 10, cy : cy + 10] = pixel.DO_NOT_USE mask[cx : cx + 10, cy : cy + 10] = False return im, mask diff --git a/romancal/source_detection/source_detection_step.py b/romancal/source_detection/source_detection_step.py index 5b2122c58..a8fec1dcb 100644 --- a/romancal/source_detection/source_detection_step.py +++ b/romancal/source_detection/source_detection_step.py @@ -17,7 +17,8 @@ ) from photutils.detection import DAOStarFinder from roman_datamodels import datamodels as rdm -from roman_datamodels import dqflags, maker_utils +from roman_datamodels import maker_utils +from roman_datamodels.dqflags import pixel from romancal.lib import psf from romancal.stpipe import RomanStep @@ -85,9 +86,7 @@ def process(self, input): # mask DO_NOT_USE pixels - self.coverage_mask = ( - (dqflags.pixel["DO_NOT_USE"]) & input_model.dq - ).astype(bool) + self.coverage_mask = ((pixel.DO_NOT_USE) & input_model.dq).astype(bool) filt = input_model.meta.instrument["optical_element"]