Skip to content

Commit

Permalink
Smaps estimation module in mri-nufft (#90)
Browse files Browse the repository at this point in the history
* Added siemens and add_raw shifts

* Update src/mrinufft/io/siemens.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Update

* Fixed some more

* Moved codes around

* Added np.ndarray

* Fix movement

* Fix movement

* Fix flake

* ruff fix

* Fix

* Remove bymistake add

* ci: runs test only for non-style commit. (#73)

* Added fixSmaps

* Fixes updates

* Fix

* fix docs

* Added smaps with blurring

* Added doc

* Final touchups

* Added compute_smaps

* Added extra files

* Added compute_smaps

* Added mask

* Added Smaps

* Updates

* Added

* Fix

* Remove bymistake add

* Fix

* Fixed lint

* Lint

* Added refbackend

* Fix NDFT

* feat: use finufft as ref backend.

* feat(tests): move ndft vs nufft tests to own file.

* Added rebart

* Update codes

* updated mask

* Fixs

* PEP

* Add lint fixes

* Added PEP fixes

* Black

* Fix black

* Fix

* Added PSF weighting

* Move to tuple

* lint

* lint

---------

Co-authored-by: Pierre-Antoine Comby <[email protected]>
Co-authored-by: Pierre-Antoine Comby <[email protected]>
Co-authored-by: chaithyagr <[email protected]>
  • Loading branch information
4 people authored May 24, 2024
1 parent da3b954 commit afb3843
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 32 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
pynufft = ["pynufft"]
io = ["pymapvbvd"]
smaps = ["scikit-image"]

test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"]
dev = ["black", "isort", "ruff"]
Expand Down
10 changes: 10 additions & 0 deletions src/mrinufft/extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Sensitivity map estimation methods."""

from .smaps import low_frequency
from .utils import get_smaps


__all__ = [
"low_frequency",
"get_smaps",
]
177 changes: 177 additions & 0 deletions src/mrinufft/extras/smaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""SMaps module for sensitivity maps estimation."""

from mrinufft.density.utils import flat_traj
from mrinufft.operators.base import get_array_module
from .utils import register_smaps
import numpy as np
from typing import Tuple


def _extract_kspace_center(
kspace_data,
kspace_loc,
threshold=None,
density=None,
window_fun="ellipse",
):
r"""Extract k-space center and corresponding sampling locations.
The extracted center of the k-space, i.e. both the kspace locations and
kspace values. If the density compensators are passed, the corresponding
compensators for the center of k-space data will also be returned. The
return dtypes for density compensation and kspace data is same as input
Parameters
----------
kspace_data: numpy.ndarray
The value of the samples
kspace_loc: numpy.ndarray
The samples location in the k-space domain (between [-0.5, 0.5[)
threshold: tuple or float
The threshold used to extract the k_space center (between (0, 1])
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None.
The window function to apply to the selected data. It is computed with
the center locations selected. Only works with circular mask.
If window_fun is a callable, it takes as input the array (n_samples x n_dims)
of sample positions and returns an array of n_samples weights to be
applied to the selected k-space values, before the smaps estimation.
Returns
-------
data_thresholded: ndarray
The k-space values in the center region.
center_loc: ndarray
The locations in the center region.
density_comp: ndarray, optional
The density compensation weights (if requested)
Notes
-----
The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as:
.. math::
w(x,y) = a_0 - (1-a_0) * \cos(\pi * \sqrt{x^2+y^2}/\theta),
\sqrt{x^2+y^2} \le \theta
In the case of Hann window :math:`a_0=0.5`.
For Hamming window we consider the optimal value in the equiripple sense:
:math:`a_0=0.53836`.
.. Wikipedia:: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
"""
xp = get_array_module(kspace_data)
if isinstance(threshold, float):
threshold = (threshold,) * kspace_loc.shape[1]

if window_fun == "rect":
data_ordered = xp.copy(kspace_data)
index = xp.linspace(
0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64
)
condition = xp.logical_and.reduce(
tuple(
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold))
)
)
index = xp.extract(condition, index)
center_locations = kspace_loc[index, :]
data_thresholded = data_ordered[:, index]
dc = density[index]
return data_thresholded, center_locations, dc
else:
if callable(window_fun):
window = window_fun(center_locations)
else:
if window_fun in ["hann", "hanning", "hamming"]:
radius = xp.linalg.norm(kspace_loc, axis=1)
a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836
window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold)
elif window_fun == "ellipse":
window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1
else:
raise ValueError("Unsupported window function.")
data_thresholded = window * kspace_data
# Return k-space locations & density just for consistency
return data_thresholded, kspace_loc, density


@register_smaps
@flat_traj
def low_frequency(
traj,
shape,
kspace_data,
backend,
threshold: float | Tuple[float, ...] = 0.1,
density=None,
window_fun: str = "ellipse",
blurr_factor: float = 0,
mask: bool = False,
):
"""
Calculate low-frequency sensitivity maps.
Parameters
----------
traj : numpy.ndarray
The trajectory of the samples.
shape : tuple
The shape of the image.
kspace_data : numpy.ndarray
The k-space data.
threshold : float, or tuple of float, optional
The threshold used for extracting the k-space center.
By default it is 0.1
backend : str
The backend used for the operator.
density : numpy.ndarray, optional
The density compensation weights.
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None.
The window function to apply to the selected data. It is computed with
the center locations selected. Only works with circular mask.
If window_fun is a callable, it takes as input the array (n_samples x n_dims)
of sample positions and returns an array of n_samples weights to be
applied to the selected k-space values, before the smaps estimation.
blurr_factor : float, optional
The blurring factor for smoothing the sensitivity maps.
mask: bool, optional default `False`
Whether the Sensitivity maps must be masked
Returns
-------
Smaps : numpy.ndarray
The low-frequency sensitivity maps.
SOS : numpy.ndarray
The sum of squares of the sensitivity maps.
"""
# defer import to later to prevent circular import
from mrinufft import get_operator
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import convex_hull_image

k_space, samples, dc = _extract_kspace_center(
kspace_data=kspace_data,
kspace_loc=traj,
threshold=threshold,
density=density,
window_fun=window_fun,
)
smaps_adj_op = get_operator(backend)(
samples, shape, density=dc, n_coils=k_space.shape[0]
)
Smaps = smaps_adj_op.adj_op(k_space)
SOS = np.linalg.norm(Smaps, axis=0)
if mask:
thresh = threshold_otsu(SOS)
# Create convex hull from mask
convex_hull = convex_hull_image(SOS > thresh)
Smaps = Smaps * convex_hull
# Smooth out the sensitivity maps
if blurr_factor > 0:
Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape))
# Re-normalize the sensitivity maps
if mask or blurr_factor > 0:
# ReCalculate SOS with a minor eps to ensure divide by 0 is ok
SOS = np.linalg.norm(Smaps, axis=0) + 1e-10
Smaps = Smaps / SOS
return Smaps, SOS
20 changes: 20 additions & 0 deletions src/mrinufft/extras/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Utils for extras module."""

from mrinufft._utils import MethodRegister

register_smaps = MethodRegister("sensitivity_maps")


def get_smaps(name, *args, **kwargs):
"""Get the density compensation function from its name."""
try:
method = register_smaps.registry["sensitivity_maps"][name]
except KeyError as e:
raise ValueError(
f"Unknown density compensation method {name}. Available methods are \n"
f"{list(register_smaps.registry['sensitivity_maps'].keys())}"
) from e

if args or kwargs:
return method(*args, **kwargs)
return method
5 changes: 4 additions & 1 deletion src/mrinufft/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Input/Output module for trajectories and data."""

from .cfl import traj2cfl, cfl2traj
from .nsp import read_trajectory, write_trajectory
from .nsp import read_trajectory, write_trajectory, read_arbgrad_rawdat
from .siemens import read_siemens_rawdat


__all__ = [
"traj2cfl",
"cfl2traj",
"read_trajectory",
"write_trajectory",
"read_arbgrad_rawdat",
"read_siemens_rawdat",
]
33 changes: 6 additions & 27 deletions src/mrinufft/io/nsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from datetime import datetime
from array import array
from .siemens import read_siemens_rawdat

from mrinufft.trajectories.utils import (
KMAX,
Expand Down Expand Up @@ -392,7 +393,7 @@ def read_trajectory(
return kspace_loc, params


def read_siemens_rawdat(
def read_arbgrad_rawdat(
filename: str,
removeOS: bool = False,
squeeze: bool = True,
Expand Down Expand Up @@ -429,32 +430,10 @@ def read_siemens_rawdat(
You can install it using the following command:
`pip install pymapVBVD`
"""
try:
from mapvbvd import mapVBVD
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
twixObj.image.flagRemoveOS = removeOS
twixObj.image.squeeze = squeeze
raw_kspace = twixObj.image[""]
data = np.moveaxis(raw_kspace, 0, 2)
hdr = {
"n_coils": int(twixObj.image.NCha),
"n_shots": int(twixObj.image.NLin),
"n_contrasts": int(twixObj.image.NSet),
"n_adc_samples": int(twixObj.image.NCol),
"n_slices": int(twixObj.image.NSli),
}
data = data.reshape(
hdr["n_coils"],
hdr["n_shots"] * hdr["n_adc_samples"],
hdr["n_slices"],
hdr["n_contrasts"],
data, hdr, twixObj = read_siemens_rawdat(
filename=filename,
removeOS=removeOS,
squeeze=squeeze,
)
if "ARBGRAD_VE11C" in data_type:
hdr["type"] = "ARBGRAD_GRE"
Expand Down
74 changes: 74 additions & 0 deletions src/mrinufft/io/siemens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Siemens specific rawdat reader, wrapper over pymapVBVD."""

import numpy as np


def read_siemens_rawdat(
filename: str,
removeOS: bool = False,
squeeze: bool = True,
return_twix: bool = True,
): # pragma: no cover
"""Read raw data from a Siemens MRI file.
Parameters
----------
filename : str
The path to the Siemens MRI file.
removeOS : bool, optional
Whether to remove the oversampling, by default False.
squeeze : bool, optional
Whether to squeeze the dimensions of the data, by default True.
data_type : str, optional
The type of data to read, by default 'ARBGRAD_VE11C'.
return_twix : bool, optional
Whether to return the twix object, by default True.
Returns
-------
data: ndarray
Imported data formatted as n_coils X n_samples X n_slices X n_contrasts
hdr: dict
Extra information about the data parsed from the twix file
Raises
------
ImportError
If the mapVBVD module is not available.
Notes
-----
This function requires the mapVBVD module to be installed.
You can install it using the following command:
`pip install pymapVBVD`
"""
try:
from mapvbvd import mapVBVD
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
twixObj.image.flagRemoveOS = removeOS
twixObj.image.squeeze = squeeze
raw_kspace = twixObj.image[""]
data = np.moveaxis(raw_kspace, 0, 2)
hdr = {
"n_coils": int(twixObj.image.NCha),
"n_shots": int(twixObj.image.NLin),
"n_contrasts": int(twixObj.image.NSet),
"n_adc_samples": int(twixObj.image.NCol),
"n_slices": int(twixObj.image.NSli),
}
data = data.reshape(
hdr["n_coils"],
hdr["n_shots"] * hdr["n_adc_samples"],
hdr["n_slices"],
hdr["n_contrasts"],
)
if return_twix:
return data, hdr, twixObj
return data, hdr
Loading

0 comments on commit afb3843

Please sign in to comment.