Skip to content

Commit

Permalink
feat: Dask get_fccd_images
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Apr 5, 2024
1 parent 4178d1c commit f182f7b
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var/
*.egg
*.eggs
doc/_build
venv/

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
53 changes: 53 additions & 0 deletions csxtools/fastccd/dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Tuple

import numpy as np
from dask.array import Array as DaskArray
from numpy.typing import ArrayLike

GAIN_8 = 0x0000
GAIN_2 = 0x8000
GAIN_1 = 0xC000
BAD_PIXEL = 0x2000
PIXEL_MASK = 0x1FFF


def correct_images(images: DaskArray, dark: ArrayLike, flat: ArrayLike, gain: Tuple[float, float, float]):
"""_summary_
Parameters
----------
images : DaskArray
Input array of images to correct of shape (N, y, x) where N is the
number of images and x and y are the image size.
dark : ArrayLike
Input array of dark images. This should be of shape (3, y, x).
dark[0] is the gain 8 (most sensitive setting) dark image with
dark[2] being the gain 1 (least sensitive) dark image.
flat : ArrayLike
Input array for the flatfield correction. This should be of shape
(y, x)
gain : Tuple[float, float, float]
These are the gain multiplication factors for the three different
gain settings
// Note GAIN_1 is the least sensitive setting which means we need to multiply the
// measured values by 8. Conversly GAIN_8 is the most sensitive and therefore only
// does not need a multiplier
"""

# Shape checking:
if dark.ndim != 3:
raise ValueError(f"Expected 3D array, got {dark.ndim}D array for darks")
if dark.shape[0] != 3:
raise ValueError(f"Expected 3 dark images, got {dark.shape[0]}")
if dark.shape[-2:] != images.shape[-2]:
raise ValueError(f"Dark images shape {dark.shape[-2:]} does not match images shape {images.shape[-2]}")
if flat.shape != images.shape[-2:]:
raise ValueError(f"Flatfield shape {flat.shape} does not match images shape {images.shape[-2]}")

corrected = np.where(images & BAD_PIXEL, np.NaN, images)
corrected = np.where(images & GAIN_1, flat * gain[-1] * (corrected - dark[-1, ...]), corrected)
corrected = np.where(images & GAIN_2, flat * gain[-2] * (corrected - dark[-2, ...]), corrected)
corrected = np.where(images & GAIN_8, flat * gain[-3] * (corrected - dark[-3, ...]), corrected)

return corrected
20 changes: 15 additions & 5 deletions csxtools/fastccd/images.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import logging
import time as ttime

import numpy as np

from ..ext import fastccd
import time as ttime
from .dask import correct_images as dask_correct_images

import logging
logger = logging.getLogger(__name__)


def correct_images(images, dark=None, flat=None, gain=(1, 4, 8)):
def correct_images(images, dark=None, flat=None, gain=(1, 4, 8), *, dask=False):
"""Subtract backgrond and gain correct images
This routine subtrtacts the backgrond and corrects the images
Expand All @@ -27,6 +30,11 @@ def correct_images(images, dark=None, flat=None, gain=(1, 4, 8)):
gain : tuple, optional
These are the gain multiplication factors for the three different
gain settings
dask : bool, optional
Do computation in dask instead of in C extension over full array.
This returns a DaskArray or DaskArrayClient with pending execution instead of a numpy array.
You can use the .compute() method to get the numpy array.
Default is False.
Returns
-------
Expand All @@ -49,8 +57,10 @@ def correct_images(images, dark=None, flat=None, gain=(1, 4, 8)):
else:
flat = np.asarray(flat, dtype=np.float32)

data = fastccd.correct_images(images.astype(np.uint16),
dark, flat, gain)
if dask:
data = dask_correct_images(images.astype(np.uint16), dark, flat, gain)
else:
data = fastccd.correct_images(images.astype(np.uint16), dark, flat, gain)
t = ttime.time() - t

logger.info("Corrected image stack in %.3f seconds", t)
Expand Down
36 changes: 36 additions & 0 deletions csxtools/image/dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Literal, Union

import dask.array as da
from dask.array import Array as DaskArray


def rotate90(images: DaskArray, sense: Union[Literal["cw"], Literal["ccw"]] = "cw") -> DaskArray:
"""
Rotate images by 90 degrees using Dask.
This whole function is a moot wrapper around `da.rot90` from Dask, but written
explicitly to match the old C code.
Parameters
----------
images : da.Array
Input Dask array of images to rotate of shape (N, y, x),
where N is the number of images and y, x are the image dimensions.
sense : str, optional
'cw' to rotate clockwise, 'ccw' to rotate anticlockwise.
Default is 'cw'.
Returns
-------
da.Array
The rotated images as a Dask array.
"""
# Rotate images. The axes (1, 2) specify the plane of rotation (y-x plane for each image).
# k controls the direction and repetitions of the rotation.
if sense == "ccw":
k = 1
elif sense == "cw":
k = -1
else:
raise ValueError("sense must be 'cw' or 'ccw'")
rotated_images = da.rot90(images, k=k, axes=(-2, -1))
return rotated_images
17 changes: 13 additions & 4 deletions csxtools/image/transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ..ext import image as extimage
from .dask import rotate90 as dask_rotate90


def rotate90(a, sense='ccw'):
def rotate90(a, sense="ccw", *, dask=True):
"""Rotate a stack of images by 90 degrees
This routine rotates a stack of images by 90. The rotation is performed
Expand All @@ -14,6 +15,11 @@ def rotate90(a, sense='ccw'):
Input array to be rotated. This should be of shape (N, y, x).
sense : string
'cw' to rotate clockwise, 'ccw' to rotate anitclockwise
dask : bool, optional
Do computation in dask instead of in C extension over full array.
This returns a DaskArray or DaskArrayClient with pending execution instead of a numpy array.
You can use the .compute() method to get the numpy array.
Default is False.
Returns
-------
Expand All @@ -22,11 +28,14 @@ def rotate90(a, sense='ccw'):
"""

if sense == 'ccw':
if sense == "ccw":
sense = 1
elif sense == 'cw':
elif sense == "cw":
sense = 0
else:
raise ValueError("sense must be 'cw' or 'ccw'")

return extimage.rotate90(a, sense)
if dask:
return dask_rotate90(a, sense)
else:
return extimage.rotate90(a, sense)
69 changes: 31 additions & 38 deletions csxtools/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import numpy as np
import logging
import time as ttime

import numpy as np
from databroker.assets.handlers import AreaDetectorHDF5TimestampHandler

from .fastccd import correct_images
from .image import rotate90, stackmean
from .settings import detectors
from databroker.assets.handlers import AreaDetectorHDF5TimestampHandler

import logging
logger = logging.getLogger(__name__)


def get_fastccd_images(light_header, dark_headers=None,
flat=None, gain=(1, 4, 8), tag=None, roi=None):
def get_fastccd_images(light_header, dark_headers=None, flat=None, gain=(1, 4, 8), tag=None, roi=None, *, dask=False):
"""Retreive and correct FastCCD Images from associated headers
Retrieve FastCCD Images from databroker and correct for:
Expand Down Expand Up @@ -50,14 +50,17 @@ def get_fastccd_images(light_header, dark_headers=None,
coordinates of the upper-left corner and width and height of
the ROI: e.g., (x, y, w, h)
dask : bool, optional
Use dask for computation. Default is False.
Returns
-------
dask.array : corrected images
"""

if tag is None:
tag = detectors['fccd']
tag = detectors["fccd"]

# Now lets sort out the ROI
if roi is not None:
Expand All @@ -72,8 +75,7 @@ def get_fastccd_images(light_header, dark_headers=None,
logger.warning("Processing without dark images")
else:
if dark_headers[0] is None:
raise NotImplementedError("Use of header metadata to find dark"
" images is not implemented yet.")
raise NotImplementedError("Use of header metadata to find dark" " images is not implemented yet.")

# Read the images for the dark headers
t = ttime.time()
Expand All @@ -91,25 +93,20 @@ def get_fastccd_images(light_header, dark_headers=None,

tt = ttime.time()
b = bgnd_events.astype(dtype=np.uint16)
logger.info("Image conversion took %.3f seconds",
ttime.time() - tt)
logger.info("Image conversion took %.3f seconds", ttime.time() - tt)

b = correct_images(b, gain=(1, 1, 1))
tt = ttime.time()
b = stackmean(b)
logger.info("Mean of image stack took %.3f seconds",
ttime.time() - tt)
logger.info("Mean of image stack took %.3f seconds", ttime.time() - tt)

else:
if (i == 0):
logger.warning("Missing dark image"
" for gain setting 8")
elif (i == 1):
logger.warning("Missing dark image"
" for gain setting 2")
elif (i == 2):
logger.warning("Missing dark image"
" for gain setting 1")
if i == 0:
logger.warning("Missing dark image" " for gain setting 8")
elif i == 1:
logger.warning("Missing dark image" " for gain setting 2")
elif i == 2:
logger.warning("Missing dark image" " for gain setting 1")

dark.append(b)

Expand All @@ -125,7 +122,7 @@ def get_fastccd_images(light_header, dark_headers=None,
if flat is not None and roi is not None:
flat = _crop(flat, roi)

return _correct_fccd_images(events, bgnd, flat, gain)
return _correct_fccd_images(events, bgnd, flat, gain, dask=dask)


def get_images_to_4D(images, dtype=None):
Expand All @@ -147,8 +144,7 @@ def get_images_to_4D(images, dtype=None):
>>> a = get_images_to_4D(images, dtype=np.float32)
"""
im = np.array([np.asarray(im, dtype=dtype) for im in images],
dtype=dtype)
im = np.array([np.asarray(im, dtype=dtype) for im in images], dtype=dtype)
return im


Expand Down Expand Up @@ -183,9 +179,9 @@ def _get_images(header, tag, roi=None):
return images


def _correct_fccd_images(image, bgnd, flat, gain):
image = correct_images(image, bgnd, flat, gain)
image = rotate90(image, 'cw')
def _correct_fccd_images(image, bgnd, flat, gain, *, dask=False):
image = correct_images(image, bgnd, flat, gain, dask=dask)
image = rotate90(image, "cw", dask=dask)
return image


Expand All @@ -196,11 +192,11 @@ def _crop_images(image, roi):
def _crop(image, roi):
image_shape = image.shape
# Assuming ROI is specified in the "rotated" (correct) orientation
roi = [image_shape[-2]-roi[3], roi[0], image_shape[-1]-roi[1], roi[2]]
return image.T[roi[1]:roi[3], roi[0]:roi[2]].T
roi = [image_shape[-2] - roi[3], roi[0], image_shape[-1] - roi[1], roi[2]]
return image.T[roi[1] : roi[3], roi[0] : roi[2]].T


def get_fastccd_timestamps(header, tag='fccd_image'):
def get_fastccd_timestamps(header, tag="fccd_image"):
"""Return the FastCCD timestamps from the Areadetector Data File
Return a list of numpy arrays of the timestamps for the images as
Expand All @@ -218,8 +214,7 @@ def get_fastccd_timestamps(header, tag='fccd_image'):
list of arrays of the timestamps
"""
with header.db.reg.handler_context(
{'AD_HDF5': AreaDetectorHDF5TimestampHandler}):
with header.db.reg.handler_context({"AD_HDF5": AreaDetectorHDF5TimestampHandler}):
timestamps = list(header.data(tag))

return timestamps
Expand Down Expand Up @@ -259,9 +254,8 @@ def calculate_flatfield(image, limits=(0.6, 1.4)):
return flat



def get_fastccd_flatfield(light, dark, flat=None, limits=(0.6, 1.4), half_interval=False):
"""Calculate a flatfield from two headers
"""Calculate a flatfield from two headers
This routine calculates the flatfield using the
:func:calculate_flatfield() function after obtaining the images from
Expand All @@ -278,7 +272,7 @@ def get_fastccd_flatfield(light, dark, flat=None, limits=(0.6, 1.4), half_interv
limits : tuple limits used for returning corrected pixel flatfield
The tuple setting lower and upper bound. np.nan returned value is outside bounds
half_interval : boolean or tuple to perform calculation for only half of the FastCCD
Default is False. If True, then the hard-code portion is retained. Customize image
Default is False. If True, then the hard-code portion is retained. Customize image
manipulation using a tuple of length 2 for (row_start, row_stop).
Expand All @@ -291,16 +285,15 @@ def get_fastccd_flatfield(light, dark, flat=None, limits=(0.6, 1.4), half_interv
images = stackmean(images)
if half_interval:
if isinstance(half_interval, bool):
row_start, row_stop = (7, 486) #hard coded for the broken half of the fccd
row_start, row_stop = (7, 486) # hard coded for the broken half of the fccd
else:
row_start, row_stop = half_interval
print(row_start, row_stop)
images[:, row_start:row_stop] = np.nan
flat = calculate_flatfield(images, limits)
removed = np.sum(np.isnan(flat))
if removed != 0:
logger.warning("Flatfield correction removed %d pixels (%.2f %%)" %
(removed, removed * 100 / flat.size))
logger.warning("Flatfield correction removed %d pixels (%.2f %%)" % (removed, removed * 100 / flat.size))
return flat


Expand Down

0 comments on commit f182f7b

Please sign in to comment.