Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random sampling #220

Merged
merged 14 commits into from
Jan 30, 2025
28 changes: 14 additions & 14 deletions examples/GPU/example_fastMRI_UNet.py
Original file line number Diff line number Diff line change
@@ -4,30 +4,30 @@
Simple UNet model.
==================
This model is a simplified version of the U-Net architecture,
which is widely used for image segmentation tasks.
This is implemented in the proprietary FASTMRI package [fastmri]_.
The U-Net model consists of an encoder (downsampling path) and
a decoder (upsampling path) with skip connections between corresponding
layers in the encoder and decoder.
These skip connections help in retaining spatial information
This model is a simplified version of the U-Net architecture,
which is widely used for image segmentation tasks.
This is implemented in the proprietary FASTMRI package [fastmri]_.
The U-Net model consists of an encoder (downsampling path) and
a decoder (upsampling path) with skip connections between corresponding
layers in the encoder and decoder.
These skip connections help in retaining spatial information
that is lost during the downsampling process.
The primary purpose of this model is to perform image reconstruction tasks,
specifically for MRI images.
It takes an input MRI image and reconstructs it to improve the image quality
The primary purpose of this model is to perform image reconstruction tasks,
specifically for MRI images.
It takes an input MRI image and reconstructs it to improve the image quality
or to recover missing parts of the image.
This implementation of the UNet model was pulled from the FastMRI Facebook
repository, which is a collaborative research project aimed at advancing
This implementation of the UNet model was pulled from the FastMRI Facebook
repository, which is a collaborative research project aimed at advancing
the field of medical imaging using machine learning techniques.
.. math::
\mathbf{\hat{x}} = \mathrm{arg} \min_{\mathbf{x}} || \mathcal{U}_\mathbf{\theta}(\mathbf{y}) - \mathbf{x} ||_2^2
where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image,
where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image,
:math:`\mathbf{y}` is the input MRI image (e.g., k-space data), and :math:`\mathcal{U}_\mathbf{\theta}` is the U-Net model parameterized by :math:`\theta`.
.. warning::
4 changes: 2 additions & 2 deletions examples/GPU/example_learn_samples.py
Original file line number Diff line number Diff line change
@@ -5,15 +5,15 @@
======================
A small pytorch example to showcase learning k-space sampling patterns.
This example showcases the auto-diff capabilities of the NUFFT operator
This example showcases the auto-diff capabilities of the NUFFT operator
wrt to k-space trajectory in mri-nufft.
In this example, we solve the following optimization problem:
.. math::
\mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2
where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed.
.. warning::
10 changes: 5 additions & 5 deletions examples/GPU/example_learn_samples_multicoil.py
Original file line number Diff line number Diff line change
@@ -5,24 +5,24 @@
=========================================
A small pytorch example to showcase learning k-space sampling patterns.
This example showcases the auto-diff capabilities of the NUFFT operator
This example showcases the auto-diff capabilities of the NUFFT operator
wrt to k-space trajectory in mri-nufft.
Briefly, in this example we try to learn the k-space samples :math:`\mathbf{K}` for the following cost function:
.. math::
\mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} x_\ell - \mathbf{x}_{sos} ||_2^2
\mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} x_\ell - \mathbf{x}_{sos} ||_2^2
where :math:`S_\ell` is the sensitivity map for the :math:`\ell`-th coil, :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}_\ell` is the image for the :math:`\ell`-th coil, and :math:`\mathbf{x}_{sos} = \sqrt{\sum_{\ell=1}^L x_\ell^2}` is the sum-of-squares image as target image to be reconstructed.
In this example, the forward NUFFT operator :math:`\mathcal{F}_\mathbf{K}` is implemented with `model.operator` while the SENSE operator :math:`model.sense_op` models the term :math:`\mathbf{A} = \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K}`.
For our data, we use a 2D slice of a 3D MRI image from the BrainWeb dataset, and the sensitivity maps are simulated using the `birdcage_maps` function from `sigpy.mri`.
.. note::
To showcase the features of ``mri-nufft``, we use ``
"cufinufft"`` backend for ``model.operator`` without density compensation and ``"gpunufft"`` backend for ``model.sense_op`` with density compensation.
"cufinufft"`` backend for ``model.operator`` without density compensation and ``"gpunufft"`` backend for ``model.sense_op`` with density compensation.
.. warning::
This example only showcases the autodiff capabilities, the learned sampling pattern is not scanner compliant as the scanner gradients required to implement it violate the hardware constraints. In practice, a projection :math:`\Pi_\mathcal{Q}(\mathbf{K})` into the scanner constraints set :math:`\mathcal{Q}` is recommended (see [Proj]_). This is implemented in the proprietary SPARKLING package [Sparks]_. Users are encouraged to contact the authors if they want to use it.
"""
2 changes: 1 addition & 1 deletion examples/example_learn_samples_multires.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
.. math::
\mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2
where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator,
:math:`D_\mathbf{K}` is the density compensator for trajectory :math:`\mathbf{K}`,
and :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed.
153 changes: 153 additions & 0 deletions examples/example_trajectory_tools.py
Original file line number Diff line number Diff line change
@@ -219,6 +219,159 @@
axes=(0, 2),
)

# %%
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
# Stack Random
# -------------
#
# A direct extension of the stacking expansion is to distribute the stacks
# according to a random distribution over the :math:`k_z`-axis.
#
# Arguments:
# - ``trajectory (array)``: array of k-space coordinates of size
# :math:`(N_c, N_s, N_d)`
# - ``dim_size (int)``: size of the kspace in voxel units
# - ``center_prop (int or float)`` : number of line
# - ``acceleration (int)``: Acceleration factor
# - ``pdf (str or array)``: Probability density function for the random distribution
# - ``rng (int or np.random.Generator)``: Random number generator
# - ``order (int)``: Order of the shots in the stack


trajectory = tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=16,
pdf="uniform",
order="top-down",
rng=42,
)

show_trajectory(trajectory, figure_size=figure_size, one_shot=one_shot)

# %%
# ``trajectory (array)``
# ~~~~~~~~~~~~~~~~~~~~~~
# The main use case is to stack trajectories consisting of
# flat or thick planes that will match the image slices.
arguments = ["Radial", "Spiral", "2D Cones", "3D Cones"]
function = lambda x: tools.stack_random(
planar_trajectories[x],
dim_size=128,
center_prop=0.1,
accel=16,
pdf="gaussian",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``dim_size (int)``
# ~~~~~~~~~~~~~~~~~~
# Size of the k-space in voxel units over the stacking direction. It
# is used to normalize the stack positions, and is used with the ``accel``
# factor and ``center_prop`` to determine the number of stacks.
arguments = [32, 64, 128]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=x,
center_prop=0.1,
accel=8,
pdf="gaussian",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``center_prop (int or float)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Number of lines to keep in the center of the k-space. It is used to determine
# the number of stacks and the acceleration factor, and to keep the center of
# the k-space with a higher density of shots. If a ``float`` this is a fraction
# of the total ``dim_size``. If ``int`` it is directly the number of lines.

arguments = [1, 5, 0.1, 0.5]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=x,
accel=16,
pdf="uniform",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)


# %%
# ``accel (int)``
# ~~~~~~~~~~~~~~~
# Acceleration factor to subsample the outer region of the k-space.
# Note that the acceleration factor does not take into account the center lines.


arguments = [1, 4, 8, 16, 32]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=x,
pdf="uniform",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``pdf (str or array)``
# ~~~~~~~~~~~~~~~~~~~~~~
# Probability density function for the sampling of the outer region. It can
# either be a string to use a known probability law ("gaussian" or "uniform") or
# "equispaced" for a coherent undersampling (like the one used in GRAPPA). It
# can also be a array, for using a customed density probability.
# In this case, it will be normalized so that ``sum(pdf) =1``.

dim_size = 128
arguments = [
"gaussian",
"uniform",
"equispaced",
np.arange(dim_size),
]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=32,
pdf=x,
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``order (str)``
# ~~~~~~~~~~~~~~~
# Determine the ordering of the shot in the trajectory.
# Accepeted values are "center-out", "top-down" or "random".
dim_size = 128
arguments = [
"center-out",
"random",
"top-down",
]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=32,
pdf="uniform",
order=x,
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# Rotate
8 changes: 8 additions & 0 deletions src/mrinufft/trajectories/__init__.py
Original file line number Diff line number Diff line change
@@ -56,6 +56,12 @@
initialize_3D_wong_radial,
)

from .tools import (
stack_random,
get_random_loc_1d,
)


__all__ = [
# trajectories
"initialize_2D_radial",
@@ -88,7 +94,9 @@
"initialize_3D_random_walk",
"initialize_3D_travelling_salesman",
# tools
"get_random_loc_1d",
"stack",
"stack_random",
"rotate",
"precess",
"conify",
6 changes: 4 additions & 2 deletions src/mrinufft/trajectories/display.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,8 @@ class displayConfig:
This can be any of the matplotlib colormaps, or a list of colors."""
one_shot_color: str = "k"
"""Matplotlib color for the highlighted shot, by default ``"k"`` (black)."""
one_shot_linewidth_factor: float = 2
"""Factor to multiply the linewidth of the highlighted shot, by default ``2``."""
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
gradient_point_color: str = "r"
"""Matplotlib color for gradient constraint points, by default ``"r"`` (red)."""
slewrate_point_color: str = "b"
@@ -243,7 +245,7 @@ def display_2D_trajectory(
trajectory[shot_id, :, 0],
trajectory[shot_id, :, 1],
color=displayConfig.one_shot_color,
linewidth=2 * displayConfig.linewidth,
linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth,
)

# Point out violated constraints if requested
@@ -379,7 +381,7 @@ def display_3D_trajectory(
trajectory[shot_id, :, 1],
trajectory[shot_id, :, 2],
color=displayConfig.one_shot_color,
linewidth=2 * displayConfig.linewidth,
linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth,
)
trajectory = trajectory.reshape((-1, Nc, Ns, 3))

183 changes: 182 additions & 1 deletion src/mrinufft/trajectories/tools.py
Original file line number Diff line number Diff line change
@@ -5,9 +5,10 @@
import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import CubicSpline, interp1d
from scipy.stats import norm

from .maths import Rv, Rx, Ry, Rz
from .utils import KMAX, initialize_tilt
from .utils import KMAX, initialize_tilt, VDSpdf, VDSorder

################
# DIRECT TOOLS #
@@ -784,3 +785,183 @@ def radialize_center(
if in_out:
return _radialize_in_out(trajectory, nb_samples)
return _radialize_center_out(trajectory, nb_samples)


#################
# Randomization #
#################


def _flip2center(mask_cols: list[int], center_value: int) -> np.ndarray:
"""
Reorder a list by starting by a center_position and alternating left/right.
Parameters
----------
mask_cols: list or np.array
List of columns to reorder.
center_pos: int
Position of the center column.
Returns
-------
np.array: reordered columns.
"""
center_pos = np.argmin(np.abs(np.array(mask_cols) - center_value))
mask_cols = list(mask_cols)
left = mask_cols[center_pos::-1]
right = mask_cols[center_pos + 1 :]
new_cols = []
while left or right:
if left:
new_cols.append(left.pop(0))
if right:
new_cols.append(right.pop(0))
return np.array(new_cols)


def get_random_loc_1d(
dim_size: int,
center_prop: float | int,
accel: float = 4,
pdf: Literal["uniform", "gaussian", "equispaced"] | NDArray = "uniform",
rng: int | np.random.Generator | None = None,
order: Literal["center-out", "top-down", "random"] = "center-out",
) -> NDArray:
"""Get slice index at a random position.
Parameters
----------
dim_size: int
Dimension size
center_prop: float or int
Proportion of center of kspace to continuouly sample
accel: float
Undersampling/Acceleration factor
pdf: str, optional
Probability density function for the remaining samples.
"gaussian" (default) or "uniform" or np.array
rng: int or np.random.Generator
random state
order: str
Order of the lines, "center-out" (default), "random" or "top-down"
Returns
-------
np.ndarray: array of size dim_size/accel.
"""
order = VDSorder(order)
pdf = VDSpdf(pdf) if isinstance(pdf, str) else pdf
if accel == 0 or accel == 1:
return np.arange(dim_size) # type: ignore
elif accel < 0:
raise ValueError("acceleration factor should be positive.")
elif isinstance(accel, float):
raise ValueError("acceleration factor should be an integer.")

indexes = list(range(dim_size))
Daval-G marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(center_prop, int):
center_prop = int(center_prop * dim_size)

center_start = (dim_size - center_prop) // 2
center_stop = (dim_size + center_prop) // 2
center_indexes = indexes[center_start:center_stop]
borders = np.asarray([*indexes[:center_start], *indexes[center_stop:]])

n_samples_borders = (dim_size - len(center_indexes)) // accel
if n_samples_borders < 1:
raise ValueError(
"acceleration factor, center_prop and dimension not compatible."
"Edges will not be sampled. "
)
rng = np.random.default_rng(rng) # get RNG from a seed or existing rng.

def _get_samples(p: np.typing.ArrayLike) -> list[int]:
p = p / np.sum(p) # automatic casting if needed
return list(rng.choice(borders, size=n_samples_borders, replace=False, p=p))

if isinstance(pdf, np.ndarray):
if len(pdf) == dim_size:
# extract the borders
p = pdf[borders]
elif len(pdf) == len(borders):
p = pdf
else:
raise ValueError("Invalid size for probability.")
sampled_in_border = _get_samples(p)

elif pdf == VDSpdf.GAUSSIAN:
p = norm.pdf(np.linspace(norm.ppf(0.001), norm.ppf(0.999), len(borders)))
sampled_in_border = _get_samples(p)
elif pdf == VDSpdf.UNIFORM:
p = np.ones(len(borders))
sampled_in_border = _get_samples(p)
elif pdf == VDSpdf.EQUISPACED:
sampled_in_border = list(borders[::accel])

else:
raise ValueError("Unsupported value for pdf use any of . ")
# TODO: allow custom pdf as argument (vector or function.)

line_locs = np.array(sorted(center_indexes + sampled_in_border))
# apply order of lines
if order == VDSorder.CENTER_OUT:
line_locs = _flip2center(sorted(line_locs), dim_size // 2)
elif order == VDSorder.RANDOM:
line_locs = rng.permutation(line_locs)
elif order == VDSorder.TOP_DOWN:
line_locs = np.array(sorted(line_locs))
else:
raise ValueError(f"Unknown direction '{order}'.")
return (line_locs / dim_size) * 2 * KMAX - KMAX # rescale to [-0.5,0.5]


def stack_random(
trajectory: NDArray,
dim_size: int,
center_prop: float | int = 0.0,
accel: float | int = 4,
pdf: Literal["uniform", "gaussian", "equispaced"] | NDArray = "uniform",
rng: int | np.random.Generator | None = None,
order: Literal["center-out", "top-down", "random"] = "center-out",
):
"""Stack a 2D trajectory with random location.
Parameters
----------
traj: np.ndarray
Existing 2D trajectory.
dim_size: int
Size of the k_z dimension
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
center_prop: int or float
Number of line or proportion of slice to sample in the center of the k-space
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
accel: int
Undersampling/Acceleration factor
pdf: str or np.array
Probability density function for the remaining samples.
"uniform" (default), "gaussian" or np.array
rng: random state
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
order: str
Order of the lines, "center-out" (default), "random" or "top-down"
Returns
-------
numpy.ndarray
The 3D trajectory stacked along the :math:`k_z` axis.
"""
line_locs = get_random_loc_1d(dim_size, center_prop, accel, pdf, rng, order)
if len(trajectory.shape) == 2:
Nc, Ns = 1, trajectory.shape[0]
else:
Nc, Ns = trajectory.shape[:2]

new_trajectory = np.zeros((len(line_locs), Nc, Ns, 3))
for i, loc in enumerate(line_locs):
new_trajectory[i, :, :, :2] = trajectory[..., :2]
if trajectory.shape[-1] == 3:
new_trajectory[i, :, :, 2] = trajectory[..., 2] + loc
else:
new_trajectory[i, :, :, 2] = loc

return new_trajectory.reshape(-1, Ns, 3)
31 changes: 29 additions & 2 deletions src/mrinufft/trajectories/utils.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,12 @@ class FloatEnum(float, Enum, metaclass=CaseInsensitiveEnumMeta):
pass


class StrEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta):
"""An Enum for str that is case insensitive for its attributes."""

pass


class Gammas(FloatEnum):
"""Enumerate gyromagnetic ratios for common nuclei in MR."""

@@ -94,7 +100,7 @@ class NormShapes(FloatEnum):
OCTAHEDRON = L1


class Tilts(str, Enum):
class Tilts(StrEnum):
r"""Enumerate available tilts.
Notes
@@ -120,7 +126,7 @@ class Tilts(str, Enum):
MRI = MRI_GOLDEN


class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta):
class Packings(StrEnum):
"""Enumerate available packing method for shots.
It is mostly used for wave-CAIPI trajectory
@@ -146,6 +152,27 @@ class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta):
SPIRAL = FIBONACCI


#############################
# Variable Density Sampling #
#############################


class VDSorder(StrEnum):
"""Available ordering for variable density sampling."""

CENTER_OUT = "center-out"
RANDOM = "random"
TOP_DOWN = "top-down"
paquiteau marked this conversation as resolved.
Show resolved Hide resolved


class VDSpdf(StrEnum):
"""Available law for variable density sampling."""

GAUSSIAN = "gaussian"
UNIFORM = "uniform"
EQUISPACED = "equispaced"


###############
# CONSTRAINTS #
###############