diff --git a/examples/GPU/example_3d_trajectory_display.py b/examples/GPU/example_3d_trajectory_display.py new file mode 100644 index 00000000..4f59bda5 --- /dev/null +++ b/examples/GPU/example_3d_trajectory_display.py @@ -0,0 +1,136 @@ +""" +============================= +3D Trajectory Gridded display +============================= + +In this example, we show some tools available to display 3D trajectories. +It can be used to understand the k-space sampling patterns, visualize the trajectories, see the sampling times, gradient strengths, slew rates etc. +Another key feature is to display the sampling density in k-space, for example to check for k-space holes or irregularities in the learning-based trajectories that would lead to artifacts in the images. +""" + +# %% + +# Imports +from mrinufft.trajectories.display3D import get_gridded_trajectory +import mrinufft.trajectories.trajectory3D as mtt +from mrinufft.trajectories.utils import Gammas +import matplotlib.pyplot as plt +import numpy as np + + +# %% +# Helper function to Displaying 3D Gridded Trajectories +# ===================================================== +# Utility function to plot mid-plane slices for 3D volumes +def plot_slices(axs, volume, title=""): + def set_labels(ax, axis_num=None): + ax.set_xticks([0, 32, 64]) + ax.set_yticks([0, 32, 64]) + ax.set_xticklabels([r"$-\pi$", 0, r"$\pi$"]) + ax.set_yticklabels([r"$-\pi$", 0, r"$\pi$"]) + if axis_num is not None: + ax.set_xlabel(r"$k_" + "zxy"[axis_num] + r"$") + ax.set_ylabel(r"$k_" + "yzx"[axis_num] + r"$") + + for i in range(3): + volume = np.rollaxis(volume, i, 0) + axs[i].imshow(volume[volume.shape[0] // 2]) + axs[i].set_title( + ((title + f"\n") if i == 0 else "") + r"$k_{" + "xyz"[i] + r"}=0$" + ) + set_labels(axs[i], i) + + +# %% +# Helper function to Displaying 3D Trajectories +# ============================================= +# Helper function to showcase the features of `get_gridded_trajectory` function +# This function will first grid the trajectory using the `get_gridded_trajectory` +# function and then plot the mid-plane slices of the gridded trajectory. +def create_grid(grid_type, trajectories, traj_params, **kwargs): + fig, axs = plt.subplots(3, 3, figsize=(10, 10)) + plt.subplots_adjust(hspace=0.5) + for i, (name, traj) in enumerate(trajectories.items()): + grid = get_gridded_trajectory( + traj, + traj_params["img_size"], + grid_type=grid_type, + traj_params=traj_params, + **kwargs, + ) + plot_slices(axs[:, i], grid, title=name) + + +# %% +# Trajectories to display +# ======================= +# Create a bunch of sampling trajectories +trajectories = { + "Radial": mtt.initialize_3D_phyllotaxis_radial(64 * 8, 64), + "FLORET": mtt.initialize_3D_floret(64 * 8, 64, nb_revolutions=2), + "Seiffert Spirals": mtt.initialize_3D_seiffert_spiral(64 * 8, 64), +} +traj_params = { + "FOV": (0.23, 0.23, 0.23), + "img_size": (64, 64, 64), + "gamma": Gammas.HYDROGEN, +} + +# %% +# Sampling density +# ================= +# Display the density of the trajectories, along the 3 mid-planes. For this, make `grid_type="density"`. +create_grid("density", trajectories, traj_params) +plt.suptitle("Sampling Density", y=0.98, x=0.52, fontsize=20) +plt.show() + + +# %% +# Sampling time +# ============= +# Display the sampling times over the trajectories. For this, make `grid_type="time"`. +# It helps to check the sampling times over the k-space trajectories, which can be responsible for excessive off-resonance artifacts. +# Note that this is just a relative visualization of sample times on a colour scale, and the actual sampling time. +create_grid("time", trajectories, traj_params) +plt.suptitle("Sampling Time", y=0.98, x=0.52, fontsize=20) +plt.show() + +# %% +# Inversion time +# ============== +# Display the inversion time of the trajectories. For this, make `grid_type="inversion"`. +# This helps in obtaining the inversion time when particular region of k-space is sampled, assuming the trajectories are time ordered, +# and the argument `turbo_factor` is specified, which is the time between 2 inversion pulses. +create_grid("inversion", trajectories, traj_params, turbo_factor=64) +plt.suptitle("Inversion Time", y=0.98, x=0.52, fontsize=20) +plt.show() +# %% +# K-space holes +# ============= +# Display the k-space holes in the trajectories. For this, make `grid_type="holes"`. +# K-space holes are areas with missing trajectory coverage, and can typically occur with learning-based trajectories when optimized using a specific loss. +# This feature can be used to identify the k-space holes, which could lead to Gibbs-like ringing artifacts in the images. +create_grid("holes", trajectories, traj_params, threshold=1e-2) +plt.suptitle("K-space Holes", y=0.98, x=0.52, fontsize=20) +plt.show() +# %% +# Gradient strength +# ================= +# Display the gradient strength of the trajectories. For this, make `grid_type="gradients"`. +# This helps in displaying the gradient strength applied at specific k-space region, +# which can be used as a surrogate to k-space "velocity", i.e. how fast does trajectory pass through a given region in k-space. +# It could be useful while characterizing spatial SNR profile in k-space +create_grid("gradients", trajectories, traj_params) +plt.suptitle("Gradient Strength", y=0.98, x=0.52, fontsize=20) +plt.show() + +# %% +# Slew rates +# =========== +# Display the slew rates of the trajectories. For this, make `grid_type="slew"`. +# This helps in displaying the slew rates applied at specific k-space region, +# which can ne used as a surrogate to k-space "acceleration", i.e. how fast does trajectory change in a given region in k-space +# It could be useful to understand potential regions in k-space with eddy current artifacts and trajectories which could lead to peripheral nerve stimulations. +create_grid("slew", trajectories, traj_params) +plt.suptitle("Slew Rates", y=0.98, x=0.52, fontsize=20) +plt.show() diff --git a/pyproject.toml b/pyproject.toml index 4a44d4c8..30458c9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ autodiff = ["torch"] test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] -doc = ["sphinx-book-theme","sphinx-copybutton", "sphinx-gallery", "matplotlib", "pooch", "brainweb-dl"] +doc = ["sphinx-book-theme","sphinx-copybutton", "sphinx-gallery", "matplotlib", "pooch", "brainweb-dl", "coverage"] # pooch is for scipy.datasets [build-system] diff --git a/src/mrinufft/trajectories/display3D.py b/src/mrinufft/trajectories/display3D.py new file mode 100644 index 00000000..524202ac --- /dev/null +++ b/src/mrinufft/trajectories/display3D.py @@ -0,0 +1,148 @@ +"""Utils for displaying 3D trajectories.""" + +from mrinufft import get_operator, get_density +from mrinufft.trajectories.utils import ( + convert_trajectory_to_gradients, + convert_gradients_to_slew_rates, + KMAX, + DEFAULT_RASTER_TIME, +) +from mrinufft.density.utils import flat_traj +import numpy as np + + +def get_gridded_trajectory( + trajectory: np.ndarray, + shape: tuple, + grid_type: str = "density", + osf: int = 1, + backend: str = "gpunufft", + traj_params: dict = None, + turbo_factor: int = 176, + elliptical_samp: bool = True, + threshold: float = 1e-3, +): + """ + Compute various trajectory characteristics onto a grid. + + This function helps in gridding a k-space sampling trajectory to a desired shape, + allowing for easier viewing of the trajectory. + The gridding process can be carried out to reflect the sampling density, + sampling time, inversion time, k-space holes, gradient strengths, or slew rates. + Please check `grid_type` parameter to know the benefits of each type of gridding. + During the gridding process, the values corresponding to various samples within the + same voxel get averaged. + + Parameters + ---------- + trajectory : ndarray + The input array of shape (N, M, D), where N is the number of shots and M is the + number of samples per shot and D is the dimension of the trajectory (usually 3) + shape : tuple + The desired shape of the gridded trajectory. + grid_type : str, optional + The type of gridded trajectory to compute. Default is "density". + It can be one of the following: + "density" : Get the sampling density in closest number of samples per voxel. + Helps understand suboptimal sampling, by showcasing regions with strong + oversampling. + "time" : Showcases when the k-space data is acquired in time. + This is helpful to view and understand off-resonance effects. + Generally, lower off-resonance effects occur when the sampling trajectory + has smoother k-space sampling time over the k-space. + "inversion" : Relative inversion time at the sampling location. Needs + `turbo_factor` to be set. This is useful for analyzing the exact inversion + time when the k-space is acquired, for sequences like MP(2)RAGE. + "holes": Show the k-space missing coverage, or holes, within a ellipsoid of the + k-space. + "gradients": Show the gradient strengths of the k-space trajectory. + "slew": Show the slew rate of the k-space trajectory. + osf : int, optional + The oversampling factor for the gridded trajectory. Default is 1. + backend : str, optional + The backend to use for gridding. Default is "gpunufft". + Note that "gpunufft" is anyway used to get the `pipe` density internally. + traj_params : dict, optional + The trajectory parameters. Default is None. + This is only needed when `grid_type` is "gradients" or "slew". + The parameters needed include `img_size` (tuple), `FOV` (tuple in `m`), + and `gamma` (float in kHz/T) of the sequence. + Generally these values are stored in the header of the trajectory file. + turbo_factor : int, optional + The turbo factor when sampling is with inversion. Default is 176, which is + the default turbo factor for MPRAGE acquisitions at 1mm whole + brain acquisitions. + elliptical_samp : bool, optional + Whether the k-space corners should be expected to be covered + or ignored when `grid_type` is "holes". Ignoring them with `True` + corresponds to trajectories with spherical/elliptical sampling. + Default is `True`. + (i.e. ellipsoid over cuboid). Use this if the trajectory is expected to be + elliptical sampling of k-space, to avoid large holes between ellipsoid and + cuboid. + threshold: float, optional default 1e-3 + The threshold for the k-space holes in number of samples per voxel + This value is set heuristically to visualize the k-space hole. + + Returns + ------- + ndarray + The gridded trajectory of shape `shape`. + """ + samples = trajectory.reshape(-1, trajectory.shape[-1]) + dcomp = get_density("pipe")(trajectory, shape) + grid_op = get_operator(backend)( + trajectory, [sh * osf for sh in shape], density=dcomp, upsampfac=1 + ) + gridded_ones = grid_op.raw_op.adj_op(np.ones(samples.shape[0]), None, True) + if grid_type == "density": + return np.abs(gridded_ones).squeeze() + elif grid_type == "time": + data = grid_op.raw_op.adj_op( + np.tile(np.linspace(1, 10, trajectory.shape[1]), (trajectory.shape[0],)), + None, + True, + ) + elif grid_type == "inversion": + data = grid_op.raw_op.adj_op( + np.repeat( + np.linspace(1, 10, turbo_factor), + samples.shape[0] // turbo_factor + 1, + )[: samples.shape[0]], + None, + True, + ) + elif grid_type == "holes": + data = np.abs(gridded_ones).squeeze() < threshold + if elliptical_samp: + # If the trajectory uses elliptical sampling, ignore the k-space holes + # outside the ellipsoid. + data[ + np.linalg.norm( + np.meshgrid( + *[np.linspace(-1, 1, sh) for sh in shape], indexing="ij" + ), + axis=0, + ) + > 1 + ] = 0 + elif grid_type in ["gradients", "slew"]: + gradients, initial_position = convert_trajectory_to_gradients( + trajectory, + norm_factor=KMAX, + resolution=np.asarray(traj_params["FOV"]) + / np.asarray(traj_params["img_size"]), + raster_time=DEFAULT_RASTER_TIME, + gamma=traj_params["gamma"], + ) + if grid_type == "gradients": + data = np.hstack( + [gradients, np.zeros((gradients.shape[0], 1, gradients.shape[2]))] + ) + else: + slews, _ = convert_gradients_to_slew_rates(gradients, DEFAULT_RASTER_TIME) + data = np.hstack([slews, np.zeros((slews.shape[0], 2, slews.shape[2]))]) + data = grid_op.raw_op.adj_op( + np.linalg.norm(data, axis=-1).flatten(), None, True + ) + return np.squeeze(np.abs(data))