Skip to content

Commit

Permalink
Use dot access instead of dict access for dqflags
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Feb 9, 2024
1 parent 7dd57e1 commit 3d73827
Show file tree
Hide file tree
Showing 18 changed files with 95 additions and 142 deletions.
7 changes: 3 additions & 4 deletions romancal/dq_init/dq_init_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import numpy as np
import roman_datamodels as rdm
from roman_datamodels import dqflags, maker_utils
from roman_datamodels import maker_utils
from roman_datamodels.datamodels import RampModel
from roman_datamodels.dqflags import pixel

from romancal.dq_init import dq_initialization
from romancal.stpipe import RomanStep
Expand Down Expand Up @@ -74,9 +75,7 @@ def process(self, input):
x_start = input_model.meta.guidestar.gw_window_xstart
x_end = input_model.meta.guidestar.gw_window_xsize + x_start
# set pixeldq array to GW_AFFECTED_DATA (2**4) for the given range
output_model.pixeldq[int(x_start) : int(x_end), :] = dqflags.pixel[
"GW_AFFECTED_DATA"
]
output_model.pixeldq[int(x_start) : int(x_end), :] = pixel.GW_AFFECTED_DATA
self.log.info(
f"Flagging rows from: {x_start} to {x_end} as affected by guide window read"
)
Expand Down
41 changes: 17 additions & 24 deletions romancal/dq_init/tests/test_dq_init.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest
from astropy import units as u
from roman_datamodels import dqflags, maker_utils, stnode
from roman_datamodels import maker_utils, stnode
from roman_datamodels.datamodels import MaskRefModel, ScienceRawModel
from roman_datamodels.dqflags import pixel

from romancal.dq_init import DQInitStep
from romancal.dq_init.dq_initialization import do_dqinit
Expand Down Expand Up @@ -55,18 +56,16 @@ def test_dq_im(xstart, ystart, xsize, ysize, ngroups, instrument, exp_type):

# assert that the pixels read back in match the mapping from ref data to
# science data
assert dqdata[100, 100] == dqflags.pixel["SATURATED"]
assert dqdata[200, 100] == dqflags.pixel["JUMP_DET"]
assert dqdata[300, 100] == dqflags.pixel["DROPOUT"]
assert dqdata[400, 100] == dqflags.pixel["PERSISTENCE"]
assert dqdata[500, 100] == dqflags.pixel["DO_NOT_USE"]
assert dqdata[600, 100] == dqflags.pixel["GW_AFFECTED_DATA"]
assert dqdata[100, 200] == dqflags.pixel["SATURATED"] + dqflags.pixel["DO_NOT_USE"]
assert dqdata[200, 200] == dqflags.pixel["JUMP_DET"] + dqflags.pixel["DO_NOT_USE"]
assert dqdata[300, 200] == dqflags.pixel["DROPOUT"] + dqflags.pixel["DO_NOT_USE"]
assert (
dqdata[400, 200] == dqflags.pixel["PERSISTENCE"] + dqflags.pixel["DO_NOT_USE"]
)
assert dqdata[100, 100] == pixel.SATURATED
assert dqdata[200, 100] == pixel.JUMP_DET
assert dqdata[300, 100] == pixel.DROPOUT
assert dqdata[400, 100] == pixel.PERSISTENCE
assert dqdata[500, 100] == pixel.DO_NOT_USE
assert dqdata[600, 100] == pixel.GW_AFFECTED_DATA
assert dqdata[100, 200] == pixel.SATURATED + pixel.DO_NOT_USE
assert dqdata[200, 200] == pixel.JUMP_DET + pixel.DO_NOT_USE
assert dqdata[300, 200] == pixel.DROPOUT + pixel.DO_NOT_USE
assert dqdata[400, 200] == pixel.PERSISTENCE + pixel.DO_NOT_USE


def test_groupdq():
Expand Down Expand Up @@ -171,15 +170,9 @@ def test_dq_add1_groupdq():

# test if pixels in pixeldq were incremented in value by 1
# check that previous dq flag is added to mask value
assert (
outfile.pixeldq[505, 505]
== dqflags.pixel["JUMP_DET"] + dqflags.pixel["DO_NOT_USE"]
)
assert outfile.pixeldq[505, 505] == pixel.JUMP_DET + pixel.DO_NOT_USE
# check two flags propagate correctly
assert (
outfile.pixeldq[400, 500]
== dqflags.pixel["SATURATED"] + dqflags.pixel["DO_NOT_USE"]
)
assert outfile.pixeldq[400, 500] == pixel.SATURATED + pixel.DO_NOT_USE


@pytest.mark.parametrize(
Expand Down Expand Up @@ -303,7 +296,7 @@ def test_dqinit_resultantdq(instrument, exptype):
wfi_sci_raw.meta["guidestar"]["gw_window_xstart"] = 1012
wfi_sci_raw.meta["guidestar"]["gw_window_xsize"] = 16
wfi_sci_raw.meta.exposure.type = exptype
wfi_sci_raw.resultantdq[1, 12, 12] = dqflags.pixel["DROPOUT"]
wfi_sci_raw.resultantdq[1, 12, 12] = pixel["DROPOUT"]
wfi_sci_raw.data = u.Quantity(
np.ones(shape, dtype=np.uint16), u.DN, dtype=np.uint16
)
Expand All @@ -330,8 +323,8 @@ def test_dqinit_resultantdq(instrument, exptype):
# check to see the resultantdq is the correct shape
assert wfi_sci_raw_model.resultantdq.shape == shape
# check to see the resultantdq & groupdq have the correct value
assert wfi_sci_raw_model.resultantdq[1, 12, 12] == dqflags.pixel["DROPOUT"]
assert result.groupdq[1, 12, 12] == dqflags.pixel["DROPOUT"]
assert wfi_sci_raw_model.resultantdq[1, 12, 12] == pixel["DROPOUT"]
assert result.groupdq[1, 12, 12] == pixel["DROPOUT"]


@pytest.mark.parametrize(
Expand Down
10 changes: 4 additions & 6 deletions romancal/flatfield/flat_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
from astropy import units as u
from roman_datamodels import dqflags
from roman_datamodels.dqflags import pixel

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -90,17 +90,15 @@ def apply_flat_field(science, flat):
# Find pixels in the flat that have a value of NaN and set
# their DQ to NO_FLAT_FIELD
flat_nan = np.isnan(flat_data)
flat_dq[flat_nan] = np.bitwise_or(flat_dq[flat_nan], dqflags.pixel["NO_FLAT_FIELD"])
flat_dq[flat_nan] = np.bitwise_or(flat_dq[flat_nan], pixel.NO_FLAT_FIELD)

# Find pixels in the flat that have a value of zero, and set
# their DQ to NO_FLAT_FIELD
flat_zero = np.where(flat_data == 0.0)
flat_dq[flat_zero] = np.bitwise_or(
flat_dq[flat_zero], dqflags.pixel["NO_FLAT_FIELD"]
)
flat_dq[flat_zero] = np.bitwise_or(flat_dq[flat_zero], pixel.NO_FLAT_FIELD)

# Find all pixels in the flat that have a DQ value of NO_FLAT_FIELD
flat_bad = np.bitwise_and(flat_dq, dqflags.pixel["NO_FLAT_FIELD"])
flat_bad = np.bitwise_and(flat_dq, pixel.NO_FLAT_FIELD)

# Reset the flat value of all bad pixels to 1.0, so that no
# correction is made
Expand Down
12 changes: 6 additions & 6 deletions romancal/jump/jump_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
from roman_datamodels import datamodels as rdd
from roman_datamodels import dqflags
from roman_datamodels.dqflags import group, pixel
from stcal.jump.jump import detect_jumps

from romancal.stpipe import RomanStep
Expand Down Expand Up @@ -118,11 +118,11 @@ def process(self, input):
# separate PR
dqflags_d = {} # Dict of DQ flags
dqflags_d = {
"GOOD": dqflags.group["GOOD"],
"DO_NOT_USE": dqflags.group["DO_NOT_USE"],
"SATURATED": dqflags.group["SATURATED"],
"JUMP_DET": dqflags.group["JUMP_DET"],
"NO_GAIN_VALUE": dqflags.pixel["NO_GAIN_VALUE"],
"GOOD": group.GOOD,
"DO_NOT_USE": group.DO_NOT_USE,
"SATURATED": group.SATURATED,
"JUMP_DET": group.JUMP_DET,
"NO_GAIN_VALUE": pixel.NO_GAIN_VALUE,
}

gdq, pdq, *_ = detect_jumps(
Expand Down
9 changes: 3 additions & 6 deletions romancal/lib/basic_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""General utility objects"""

import numpy as np
from roman_datamodels import dqflags
from roman_datamodels.datamodels import AssociationsModel

SATURATEDPIX = dqflags.pixel["SATURATED"]
SATURATEDGRP = dqflags.group["SATURATED"]
from roman_datamodels.dqflags import group, pixel


def bytes2human(n):
Expand Down Expand Up @@ -48,9 +45,9 @@ def is_fully_saturated(model):
Check to see if all data pixels are flagged as saturated.
"""

if np.all(np.bitwise_and(model.groupdq, SATURATEDGRP) == SATURATEDGRP):
if np.all(np.bitwise_and(model.groupdq, group.SATURATED) == group.SATURATED):

Check warning on line 48 in romancal/lib/basic_utils.py

View check run for this annotation

Codecov / codecov/patch

romancal/lib/basic_utils.py#L48

Added line #L48 was not covered by tests
return True
elif np.all(np.bitwise_and(model.pixeldq, SATURATEDPIX) == SATURATEDPIX):
elif np.all(np.bitwise_and(model.pixeldq, pixel.SATURATED) == pixel.SATURATED):

Check warning on line 50 in romancal/lib/basic_utils.py

View check run for this annotation

Codecov / codecov/patch

romancal/lib/basic_utils.py#L50

Added line #L50 was not covered by tests
return True

return False
Expand Down
4 changes: 2 additions & 2 deletions romancal/lib/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SourceGrouper,
)
from roman_datamodels.datamodels import ImageModel
from roman_datamodels.dqflags import pixel as roman_dq_flag_enum
from roman_datamodels.dqflags import pixel
from webbpsf import conf, gridded_library, restart_logging

__all__ = [
Expand Down Expand Up @@ -388,7 +388,7 @@ def dq_to_boolean_mask(image_model_or_dq, ignore_flags=0, flag_map_name="ROMAN_D
dq = image_model_or_dq

# add the Roman DQ flags to the astropy bitmask registry:
dq_flag_map = {dq.name: dq.value for dq in roman_dq_flag_enum if dq.name != "GOOD"}
dq_flag_map = {dq.name: dq.value for dq in pixel if dq.name != "GOOD"}
bitmask.extend_bit_flag_map(flag_map_name, **dq_flag_map)

# convert the bitmask to a boolean mask:
Expand Down
4 changes: 2 additions & 2 deletions romancal/linearity/linearity_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from astropy import units as u
from roman_datamodels import datamodels as rdd
from roman_datamodels import dqflags
from roman_datamodels.dqflags import pixel
from stcal.linearity.linearity import linearity_correction

from romancal.stpipe import RomanStep
Expand Down Expand Up @@ -53,7 +53,7 @@ def process(self, input):
# The third return value is the procesed zero frame which
# Roman does not use.
new_data, new_pdq, _ = linearity_correction(
input_model.data.value, gdq, pdq, lin_coeffs, lin_dq, dqflags.pixel
input_model.data.value, gdq, pdq, lin_coeffs, lin_dq, pixel
)

input_model.data = u.Quantity(
Expand Down
13 changes: 6 additions & 7 deletions romancal/outlier_detection/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astropy.units import Quantity
from drizzle.cdrizzle import tblot
from roman_datamodels import datamodels as rdm
from roman_datamodels import dqflags
from roman_datamodels.dqflags import pixel
from scipy import ndimage

from romancal.datamodels import ModelContainer
Expand All @@ -21,9 +21,6 @@
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

DO_NOT_USE = dqflags.pixel["DO_NOT_USE"]
OUTLIER = dqflags.pixel["OUTLIER"]


__all__ = ["OutlierDetection", "flag_cr", "abs_deriv"]

Expand Down Expand Up @@ -375,13 +372,15 @@ def flag_cr(
cr_mask = np.greater(diff_noise.value, snr1 * err_data.value)

# Count existing DO_NOT_USE pixels
count_existing = np.count_nonzero(sci_image.dq & DO_NOT_USE)
count_existing = np.count_nonzero(sci_image.dq & pixel.DO_NOT_USE)

# Update the DQ array in the input image.
sci_image.dq = np.bitwise_or(sci_image.dq, cr_mask * (DO_NOT_USE | OUTLIER))
sci_image.dq = np.bitwise_or(
sci_image.dq, cr_mask * (pixel.DO_NOT_USE | pixel.OUTLIER)
)

# Report number (and percent) of new DO_NOT_USE pixels found
count_outlier = np.count_nonzero(sci_image.dq & DO_NOT_USE)
count_outlier = np.count_nonzero(sci_image.dq & pixel.DO_NOT_USE)
count_added = count_outlier - count_existing
percent_cr = count_added / (sci_image.shape[0] * sci_image.shape[1]) * 100
log.info(f"New pixels flagged as outliers: {count_added} ({percent_cr:.2f}%)")
Expand Down
6 changes: 2 additions & 4 deletions romancal/pipeline/exposure_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from roman_datamodels import datamodels as rdm
from roman_datamodels import dqflags
from roman_datamodels.dqflags import group

import romancal.datamodels.filetype as filetype

Expand Down Expand Up @@ -236,9 +236,7 @@ def create_fully_saturated_zeroed_image(self, input_model):
input_model,
(
np.zeros(input_model.data.shape[1:], dtype=input_model.data.dtype),
input_model.pixeldq
| input_model.groupdq[0]
| dqflags.group["SATURATED"],
input_model.pixeldq | input_model.groupdq[0] | group.SATURATED,
np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype),
np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype),
np.zeros(input_model.err.shape[1:], dtype=input_model.err.dtype),
Expand Down
23 changes: 11 additions & 12 deletions romancal/ramp_fitting/ramp_fit_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import numpy as np
from astropy import units as u
from roman_datamodels import datamodels as rdd
from roman_datamodels import dqflags, maker_utils
from roman_datamodels import maker_utils
from roman_datamodels import stnode as rds
from roman_datamodels.dqflags import group, pixel
from stcal.ramp_fitting import ols_cas22_fit, ramp_fit
from stcal.ramp_fitting.ols_cas22 import Parameter, Variance

Expand Down Expand Up @@ -105,7 +106,7 @@ def ols(self, input_model, readnoise_model, gain_model):
self.algorithm,
self.weighting,
max_cores,
dqflags.pixel,
pixel,
)

if image_info is not None:
Expand Down Expand Up @@ -149,9 +150,7 @@ def ols(self, input_model, readnoise_model, gain_model):
# Image info order is: data, dq, var_poisson, var_rnoise, err
image_info = (
np.zeros(input_model.data.shape[2:], dtype=input_model.data.dtype),
input_model.pixeldq
| input_model.groupdq[0][0]
| dqflags.group["SATURATED"],
input_model.pixeldq | input_model.groupdq[0][0] | group.SATURATED,
np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype),
np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype),
np.zeros(input_model.err.shape[2:], dtype=input_model.err.dtype),
Expand Down Expand Up @@ -396,17 +395,17 @@ def get_pixeldq_flags(groupdq, pixeldq, slopes, err, gain):
"""
outpixeldq = pixeldq.copy()
# jump flagging
m = np.any(groupdq & dqflags.group["JUMP_DET"], axis=0)
outpixeldq |= (m * dqflags.pixel["JUMP_DET"]).astype(np.uint32)
m = np.any(groupdq & group.JUMP_DET, axis=0)
outpixeldq |= (m * pixel.JUMP_DET).astype(np.uint32)
# all saturated flagging
m = np.all(groupdq & dqflags.group["SATURATED"], axis=0)
outpixeldq |= (m * dqflags.pixel["SATURATED"]).astype(np.uint32)
m = np.all(groupdq & group.SATURATED, axis=0)
outpixeldq |= (m * pixel.SATURATED).astype(np.uint32)
# all either saturated or do not use or NaN slope flagging
satordnu = dqflags.group["SATURATED"] | dqflags.group["DO_NOT_USE"]
satordnu = group.SATURATED | group.DO_NOT_USE
m = np.all(groupdq & satordnu, axis=0)
m |= ~np.isfinite(slopes) | (err <= 0)
outpixeldq |= (m * dqflags.pixel["DO_NOT_USE"]).astype(np.uint32)
outpixeldq |= (m * pixel.DO_NOT_USE).astype(np.uint32)
m = (gain < 0) | ~np.isfinite(gain)
outpixeldq |= (m * dqflags.pixel["NO_GAIN_VALUE"]).astype(np.uint32)
outpixeldq |= (m * pixel.NO_GAIN_VALUE).astype(np.uint32)

return outpixeldq
12 changes: 1 addition & 11 deletions romancal/ramp_fitting/tests/test_ramp_fit_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from astropy import units as u
from astropy.time import Time
from roman_datamodels import dqflags, maker_utils
from roman_datamodels import maker_utils
from roman_datamodels.datamodels import GainRefModel, RampModel, ReadnoiseRefModel

from romancal.ramp_fitting import RampFitStep
Expand All @@ -15,16 +15,6 @@
# Used to deconstruct the MultiAccum tables into integration times.
ROMAN_READ_TIME = 3.04

DO_NOT_USE = dqflags.group["DO_NOT_USE"]
JUMP_DET = dqflags.group["JUMP_DET"]
SATURATED = dqflags.group["SATURATED"]

dqflags = {
"DO_NOT_USE": 1,
"SATURATED": 2,
"JUMP_DET": 4,
}

# Basic resultant
#
# The read pattern is `[[1], [2], [3], [4]]`
Expand Down
Loading

0 comments on commit 3d73827

Please sign in to comment.