Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
ENH: Add helper functions for DWI signal value visualization
Browse files Browse the repository at this point in the history
Add helper functions for DWI signal value visualization.
  • Loading branch information
jhlegarreta committed Oct 24, 2024
1 parent 796c501 commit e4d2de8
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions src/eddymotion/viz/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import matplotlib.gridspec as gridspec
import numpy as np
from matplotlib import pyplot as plt
from scipy.spatial import ConvexHull, KDTree
from scipy.stats import pearsonr


Expand Down Expand Up @@ -112,3 +113,96 @@ def plot_correlation(x, y, title):
fig.tight_layout()

return fig, r


def calculate_sphere_pts(points, center):
"""Calculate the location of each point when it is expanded out to the sphere."""

kdtree = KDTree(points) # tree of nearest points
# d is an array of distances, i is an array of indices
d, i = kdtree.query(center, points.shape[0])
sphere_pts = np.zeros(points.shape, dtype=float)

radius = np.amax(d)
for p in range(points.shape[0]):
sphere_pts[p] = points[i[p]] * radius / d[p]
# points and the indices for where they were in the original lists
return sphere_pts, i


def compute_dmri_convex_hull(s, dirs, mask=None):
"""Compute the convex hull of the dMRI signal s."""

if mask is None:
mask = np.ones(len(dirs), dtype=bool)

# Scale the original sampling directions by the corresponding signal values
scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis]

# Create the data for the convex hull: project the scaled vectors to a
# sphere
sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0])

# Create the convex hull: find the right ordering of vertices for the
# triangles: ConvexHull finds the simplices of the points on the outside of
# the data set
hull = ConvexHull(sphere_pts)
triang_idx = hull.simplices # returns the list of indices for each triangle

return scaled_bvecs, sphere_idx, triang_idx


def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap):
"""Plot a surface."""

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

ax.scatter3D(
scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0
)

surface = ax.plot_trisurf(
scaled_vecs[sphere_idx, 0],
scaled_vecs[sphere_idx, 1],
scaled_vecs[sphere_idx, 2],
triangles=triang_idx,
cmap=cmap,
alpha=0.6,
)

ax.view_init(10, 45)
ax.set_aspect("equal", adjustable="box")
ax.set_title(title)

return fig, ax, surface


def plot_signal_data(y, ax):
"""Plot the data provided as a scatter plot"""

ax.scatter(
y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points"
)


def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap):
"""Plot the prediction surface obtained by computing the convex hull of the
predicted signal data, and plot the true data as a scatter plot."""

# Scale the original sampling directions by the corresponding signal values
y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis]

# Compute the convex hull
y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs)

# Plot the surface
fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap)

# Add the underlying signal to the plot
# plot_signal_data(y_bvecs/S0, ax)
plot_signal_data(y_bvecs, ax)

fig.tight_layout()

return fig, ax, surface

0 comments on commit e4d2de8

Please sign in to comment.