-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Smaps estimation module in mri-nufft (#90)
* Added siemens and add_raw shifts * Update src/mrinufft/io/siemens.py Co-authored-by: Pierre-Antoine Comby <[email protected]> * Update * Fixed some more * Moved codes around * Added np.ndarray * Fix movement * Fix movement * Fix flake * ruff fix * Fix * Remove bymistake add * ci: runs test only for non-style commit. (#73) * Added fixSmaps * Fixes updates * Fix * fix docs * Added smaps with blurring * Added doc * Final touchups * Added compute_smaps * Added extra files * Added compute_smaps * Added mask * Added Smaps * Updates * Added * Fix * Remove bymistake add * Fix * Fixed lint * Lint * Added refbackend * Fix NDFT * feat: use finufft as ref backend. * feat(tests): move ndft vs nufft tests to own file. * Added rebart * Update codes * updated mask * Fixs * PEP * Add lint fixes * Added PEP fixes * Black * Fix black * Fix * Added PSF weighting * Move to tuple * lint * lint --------- Co-authored-by: Pierre-Antoine Comby <[email protected]> Co-authored-by: Pierre-Antoine Comby <[email protected]> Co-authored-by: chaithyagr <[email protected]>
- Loading branch information
1 parent
da3b954
commit afb3843
Showing
10 changed files
with
351 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Sensitivity map estimation methods.""" | ||
|
||
from .smaps import low_frequency | ||
from .utils import get_smaps | ||
|
||
|
||
__all__ = [ | ||
"low_frequency", | ||
"get_smaps", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
"""SMaps module for sensitivity maps estimation.""" | ||
|
||
from mrinufft.density.utils import flat_traj | ||
from mrinufft.operators.base import get_array_module | ||
from .utils import register_smaps | ||
import numpy as np | ||
from typing import Tuple | ||
|
||
|
||
def _extract_kspace_center( | ||
kspace_data, | ||
kspace_loc, | ||
threshold=None, | ||
density=None, | ||
window_fun="ellipse", | ||
): | ||
r"""Extract k-space center and corresponding sampling locations. | ||
The extracted center of the k-space, i.e. both the kspace locations and | ||
kspace values. If the density compensators are passed, the corresponding | ||
compensators for the center of k-space data will also be returned. The | ||
return dtypes for density compensation and kspace data is same as input | ||
Parameters | ||
---------- | ||
kspace_data: numpy.ndarray | ||
The value of the samples | ||
kspace_loc: numpy.ndarray | ||
The samples location in the k-space domain (between [-0.5, 0.5[) | ||
threshold: tuple or float | ||
The threshold used to extract the k_space center (between (0, 1]) | ||
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. | ||
The window function to apply to the selected data. It is computed with | ||
the center locations selected. Only works with circular mask. | ||
If window_fun is a callable, it takes as input the array (n_samples x n_dims) | ||
of sample positions and returns an array of n_samples weights to be | ||
applied to the selected k-space values, before the smaps estimation. | ||
Returns | ||
------- | ||
data_thresholded: ndarray | ||
The k-space values in the center region. | ||
center_loc: ndarray | ||
The locations in the center region. | ||
density_comp: ndarray, optional | ||
The density compensation weights (if requested) | ||
Notes | ||
----- | ||
The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as: | ||
.. math:: | ||
w(x,y) = a_0 - (1-a_0) * \cos(\pi * \sqrt{x^2+y^2}/\theta), | ||
\sqrt{x^2+y^2} \le \theta | ||
In the case of Hann window :math:`a_0=0.5`. | ||
For Hamming window we consider the optimal value in the equiripple sense: | ||
:math:`a_0=0.53836`. | ||
.. Wikipedia:: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows | ||
""" | ||
xp = get_array_module(kspace_data) | ||
if isinstance(threshold, float): | ||
threshold = (threshold,) * kspace_loc.shape[1] | ||
|
||
if window_fun == "rect": | ||
data_ordered = xp.copy(kspace_data) | ||
index = xp.linspace( | ||
0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64 | ||
) | ||
condition = xp.logical_and.reduce( | ||
tuple( | ||
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) | ||
) | ||
) | ||
index = xp.extract(condition, index) | ||
center_locations = kspace_loc[index, :] | ||
data_thresholded = data_ordered[:, index] | ||
dc = density[index] | ||
return data_thresholded, center_locations, dc | ||
else: | ||
if callable(window_fun): | ||
window = window_fun(center_locations) | ||
else: | ||
if window_fun in ["hann", "hanning", "hamming"]: | ||
radius = xp.linalg.norm(kspace_loc, axis=1) | ||
a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 | ||
window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) | ||
elif window_fun == "ellipse": | ||
window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 | ||
else: | ||
raise ValueError("Unsupported window function.") | ||
data_thresholded = window * kspace_data | ||
# Return k-space locations & density just for consistency | ||
return data_thresholded, kspace_loc, density | ||
|
||
|
||
@register_smaps | ||
@flat_traj | ||
def low_frequency( | ||
traj, | ||
shape, | ||
kspace_data, | ||
backend, | ||
threshold: float | Tuple[float, ...] = 0.1, | ||
density=None, | ||
window_fun: str = "ellipse", | ||
blurr_factor: float = 0, | ||
mask: bool = False, | ||
): | ||
""" | ||
Calculate low-frequency sensitivity maps. | ||
Parameters | ||
---------- | ||
traj : numpy.ndarray | ||
The trajectory of the samples. | ||
shape : tuple | ||
The shape of the image. | ||
kspace_data : numpy.ndarray | ||
The k-space data. | ||
threshold : float, or tuple of float, optional | ||
The threshold used for extracting the k-space center. | ||
By default it is 0.1 | ||
backend : str | ||
The backend used for the operator. | ||
density : numpy.ndarray, optional | ||
The density compensation weights. | ||
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. | ||
The window function to apply to the selected data. It is computed with | ||
the center locations selected. Only works with circular mask. | ||
If window_fun is a callable, it takes as input the array (n_samples x n_dims) | ||
of sample positions and returns an array of n_samples weights to be | ||
applied to the selected k-space values, before the smaps estimation. | ||
blurr_factor : float, optional | ||
The blurring factor for smoothing the sensitivity maps. | ||
mask: bool, optional default `False` | ||
Whether the Sensitivity maps must be masked | ||
Returns | ||
------- | ||
Smaps : numpy.ndarray | ||
The low-frequency sensitivity maps. | ||
SOS : numpy.ndarray | ||
The sum of squares of the sensitivity maps. | ||
""" | ||
# defer import to later to prevent circular import | ||
from mrinufft import get_operator | ||
from skimage.filters import threshold_otsu, gaussian | ||
from skimage.morphology import convex_hull_image | ||
|
||
k_space, samples, dc = _extract_kspace_center( | ||
kspace_data=kspace_data, | ||
kspace_loc=traj, | ||
threshold=threshold, | ||
density=density, | ||
window_fun=window_fun, | ||
) | ||
smaps_adj_op = get_operator(backend)( | ||
samples, shape, density=dc, n_coils=k_space.shape[0] | ||
) | ||
Smaps = smaps_adj_op.adj_op(k_space) | ||
SOS = np.linalg.norm(Smaps, axis=0) | ||
if mask: | ||
thresh = threshold_otsu(SOS) | ||
# Create convex hull from mask | ||
convex_hull = convex_hull_image(SOS > thresh) | ||
Smaps = Smaps * convex_hull | ||
# Smooth out the sensitivity maps | ||
if blurr_factor > 0: | ||
Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) | ||
# Re-normalize the sensitivity maps | ||
if mask or blurr_factor > 0: | ||
# ReCalculate SOS with a minor eps to ensure divide by 0 is ok | ||
SOS = np.linalg.norm(Smaps, axis=0) + 1e-10 | ||
Smaps = Smaps / SOS | ||
return Smaps, SOS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""Utils for extras module.""" | ||
|
||
from mrinufft._utils import MethodRegister | ||
|
||
register_smaps = MethodRegister("sensitivity_maps") | ||
|
||
|
||
def get_smaps(name, *args, **kwargs): | ||
"""Get the density compensation function from its name.""" | ||
try: | ||
method = register_smaps.registry["sensitivity_maps"][name] | ||
except KeyError as e: | ||
raise ValueError( | ||
f"Unknown density compensation method {name}. Available methods are \n" | ||
f"{list(register_smaps.registry['sensitivity_maps'].keys())}" | ||
) from e | ||
|
||
if args or kwargs: | ||
return method(*args, **kwargs) | ||
return method |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
"""Input/Output module for trajectories and data.""" | ||
|
||
from .cfl import traj2cfl, cfl2traj | ||
from .nsp import read_trajectory, write_trajectory | ||
from .nsp import read_trajectory, write_trajectory, read_arbgrad_rawdat | ||
from .siemens import read_siemens_rawdat | ||
|
||
|
||
__all__ = [ | ||
"traj2cfl", | ||
"cfl2traj", | ||
"read_trajectory", | ||
"write_trajectory", | ||
"read_arbgrad_rawdat", | ||
"read_siemens_rawdat", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Siemens specific rawdat reader, wrapper over pymapVBVD.""" | ||
|
||
import numpy as np | ||
|
||
|
||
def read_siemens_rawdat( | ||
filename: str, | ||
removeOS: bool = False, | ||
squeeze: bool = True, | ||
return_twix: bool = True, | ||
): # pragma: no cover | ||
"""Read raw data from a Siemens MRI file. | ||
Parameters | ||
---------- | ||
filename : str | ||
The path to the Siemens MRI file. | ||
removeOS : bool, optional | ||
Whether to remove the oversampling, by default False. | ||
squeeze : bool, optional | ||
Whether to squeeze the dimensions of the data, by default True. | ||
data_type : str, optional | ||
The type of data to read, by default 'ARBGRAD_VE11C'. | ||
return_twix : bool, optional | ||
Whether to return the twix object, by default True. | ||
Returns | ||
------- | ||
data: ndarray | ||
Imported data formatted as n_coils X n_samples X n_slices X n_contrasts | ||
hdr: dict | ||
Extra information about the data parsed from the twix file | ||
Raises | ||
------ | ||
ImportError | ||
If the mapVBVD module is not available. | ||
Notes | ||
----- | ||
This function requires the mapVBVD module to be installed. | ||
You can install it using the following command: | ||
`pip install pymapVBVD` | ||
""" | ||
try: | ||
from mapvbvd import mapVBVD | ||
except ImportError as err: | ||
raise ImportError( | ||
"The mapVBVD module is not available. Please install it using " | ||
"the following command: pip install pymapVBVD" | ||
) from err | ||
twixObj = mapVBVD(filename) | ||
if isinstance(twixObj, list): | ||
twixObj = twixObj[-1] | ||
twixObj.image.flagRemoveOS = removeOS | ||
twixObj.image.squeeze = squeeze | ||
raw_kspace = twixObj.image[""] | ||
data = np.moveaxis(raw_kspace, 0, 2) | ||
hdr = { | ||
"n_coils": int(twixObj.image.NCha), | ||
"n_shots": int(twixObj.image.NLin), | ||
"n_contrasts": int(twixObj.image.NSet), | ||
"n_adc_samples": int(twixObj.image.NCol), | ||
"n_slices": int(twixObj.image.NSli), | ||
} | ||
data = data.reshape( | ||
hdr["n_coils"], | ||
hdr["n_shots"] * hdr["n_adc_samples"], | ||
hdr["n_slices"], | ||
hdr["n_contrasts"], | ||
) | ||
if return_twix: | ||
return data, hdr, twixObj | ||
return data, hdr |
Oops, something went wrong.