Skip to content

Commit

Permalink
Delete create_simple_connectivity_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Oct 31, 2024
1 parent ce0a2d7 commit 3ec6d1e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 388 deletions.
164 changes: 1 addition & 163 deletions dwi_ml/data/processing/streamlines/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,100 +301,9 @@ def _compute_origin_finish_blocs(streamlines, volume_size, nb_blocs):
return start_block, end_block


def compute_triu_connectivity_from_labels(streamlines, data_labels,
use_scilpy=False):
"""
Compute a connectivity matrix.
Parameters
----------
streamlines: list of np arrays or list of tensors.
Streamlines, in vox space, corner origin.
data_labels: np.ndarray
The loaded nifti image.
use_scilpy: bool
If True, uses scilpy's method:
'Strategy is to keep the longest streamline segment
connecting 2 regions. If the streamline crosses other gray
matter regions before reaching its final connected region,
the kept connection is still the longest. This is robust to
compressed streamlines.'
Else, uses simple computation from endpoints. Faster. Also, works with
incomplete parcellation.
Returns
-------
matrix: np.ndarray
With use_scilpy: shape (nb_labels + 1, nb_labels + 1)
(last label is "Not Found")
Else, shape (nb_labels, nb_labels)
labels: List
The list of labels
start_labels: List
For each streamline, the label at starting point.
end_labels: List
For each streamline, the label at ending point.
"""
real_labels = list(np.sort(np.unique(data_labels)))
nb_labels = len(real_labels)
logging.debug("Computing connectivity matrix for {} labels."
.format(nb_labels))

if use_scilpy:
matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int)
else:
matrix = np.zeros((nb_labels, nb_labels), dtype=int)

start_labels = []
end_labels = []

if use_scilpy:
indices, points_to_idx = uncompress(streamlines, return_mapping=True)

for strl_vox_indices in indices:
segments_info = segmenting_func(strl_vox_indices, data_labels)
if len(segments_info) > 0:
start = real_labels.index(segments_info[0]['start_label'])
end = real_labels.index(segments_info[0]['end_label'])
else:
start = nb_labels
end = nb_labels

start_labels.append(start)
end_labels.append(end)

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1

real_labels = real_labels + [np.NaN]

else:
for s in streamlines:
# Vox space, corner origin
# = we can get the nearest neighbor easily.
# Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1.
start = real_labels.index(
data_labels[tuple(np.floor(s[0, :]).astype(int))])
end = real_labels.index(
data_labels[tuple(np.floor(s[-1, :]).astype(int))])

start_labels.append(start)
end_labels.append(end)

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1

matrix = np.triu(matrix)
assert matrix.sum() == len(streamlines)

return matrix, real_labels, start_labels, end_labels


def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs):
"""
Compute a connectivity matrix.
Compute a connectivity matrix using blocs of the volume instead of labels.
Parameters
----------
Expand Down Expand Up @@ -427,74 +336,3 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs):
assert matrix.sum() == len(streamlines)

return matrix, start_block, end_block


def prepare_figure_connectivity(matrix):
matrix = np.copy(matrix)

fig, axs = plt.subplots(2, 2)
im = axs[0, 0].imshow(matrix)
divider = make_axes_locatable(axs[0, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[0, 0].set_title("Raw streamline count")

im = axs[0, 1].imshow(matrix + np.min(matrix[matrix > 0]), norm=LogNorm())
divider = make_axes_locatable(axs[0, 1])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[0, 1].set_title("Raw streamline count (log view)")

matrix = matrix / matrix.sum() * 100
im = axs[1, 0].imshow(matrix)
divider = make_axes_locatable(axs[1, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[1, 0].set_title("Percentage")

matrix = matrix > 0
axs[1, 1].imshow(matrix)
axs[1, 1].set_title("Binary")

plt.suptitle("All versions of the connectivity matrix.")


def find_streamlines_with_chosen_connectivity(
streamlines, start_labels, end_labels, label1, label2=None):
"""
Returns streamlines corresponding to a (label1, label2) or (label2, label1)
connection.
Parameters
----------
streamlines: list of np arrays or list of tensors.
Streamlines, in vox space, corner origin.
start_labels: list[int]
The starting bloc for each streamline.
end_labels: list[int]
The ending bloc for each streamline.
label1: int
The bloc of interest, either as starting or finishing point.
label2: int, optional
The bloc of interest, either as starting or finishing point.
If label2 is None, then all connections (label1, Y) and (X, label1)
are found.
"""
start_labels = np.asarray(start_labels)
end_labels = np.asarray(end_labels)

if label2 is None:
labels2 = np.unique(np.concatenate((start_labels[:], end_labels[:])))
else:
labels2 = [label2]

found = np.zeros(len(streamlines))
for label2 in labels2:
str_ind1 = np.logical_and(start_labels == label1,
end_labels == label2)
str_ind2 = np.logical_and(start_labels == label2,
end_labels == label1)
str_ind = np.logical_or(str_ind1, str_ind2)
found = np.logical_or(found, str_ind)

return [s for i, s in enumerate(streamlines) if found[i]]
8 changes: 5 additions & 3 deletions dwi_ml/training/trainers_withGV.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@
import torch
from torch.nn import PairwiseDistance

from scilpy.connectivity.connectivity import \
compute_triu_connectivity_from_labels

from dwi_ml.data.processing.streamlines.post_processing import \
compute_triu_connectivity_from_blocs, compute_triu_connectivity_from_labels
compute_triu_connectivity_from_blocs
from dwi_ml.models.main_models import ModelWithDirectionGetter
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.tracking.io_utils import prepare_tracking_mask
Expand Down Expand Up @@ -356,8 +359,7 @@ def _compare_connectivity(self, lines, ids_per_subj):
else:
# Note: scilpy usage not ready! Simple endpoints position
batch_matrix, _, _, _ =\
compute_triu_connectivity_from_labels(
_lines, labels, use_scilpy=False)
compute_triu_connectivity_from_labels(_lines, labels)

# Where our batch has a 0: not important, maybe it was simply
# not in this batch.
Expand Down
72 changes: 44 additions & 28 deletions scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This script creates a connectivity matrix (using streamline count, or binary)
when you don't have labels for your data, using a division of the volume into
N blocs. Useful for supervised machine learning models.
If you do have labels, see
>> scil_connectivity_compute_simple_matrix.py
"""
import argparse
import logging
import os.path

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from dipy.io.streamline import save_tractogram
from dipy.io.utils import is_header_compatible
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \
add_verbose_arg, add_overwrite_arg

from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity
from dwi_ml.data.processing.streamlines.post_processing import \
compute_triu_connectivity_from_blocs, \
find_streamlines_with_chosen_connectivity, prepare_figure_connectivity
compute_triu_connectivity_from_blocs


def _build_arg_parser():
Expand All @@ -42,19 +52,43 @@ def _build_arg_parser():
p.add_argument('--show_now', action='store_true',
help="If set, shows the matrix with matplotlib.")

g = p.add_argument_group("Investigation of the matrix:")
g.add_argument('--save_biggest', metavar='filename',
help="If set, saves the biggest bundle (as tck or trk).")
g.add_argument('--save_smallest', metavar='filename',
help="If set, saves the smallest (non-zero) bundle "
"(as tck or trk).")

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def prepare_figure_connectivity(matrix):
# Equivalent to the figure in scil_connectivity_compute_simple_matrix.py
matrix = np.copy(matrix)

fig, axs = plt.subplots(2, 2)
im = axs[0, 0].imshow(matrix)
divider = make_axes_locatable(axs[0, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[0, 0].set_title("Raw streamline count")

im = axs[0, 1].imshow(matrix + np.min(matrix[matrix > 0]), norm=LogNorm())
divider = make_axes_locatable(axs[0, 1])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[0, 1].set_title("Raw streamline count (log view)")

matrix = matrix / matrix.sum() * 100
im = axs[1, 0].imshow(matrix)
divider = make_axes_locatable(axs[1, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
axs[1, 0].set_title("Percentage of the total streamline count")

matrix = matrix > 0
axs[1, 1].imshow(matrix)
axs[1, 1].set_title("Binary matrix: 1 if at least 1 streamline")

plt.suptitle("Connectivity matrix: streamline count")


def main():
p = _build_arg_parser()
args = p.parse_args()
Expand Down Expand Up @@ -95,24 +129,6 @@ def main():
np.save(args.out_file, matrix)
plt.savefig(out_fig)

# Options to try to investigate the connectivity matrix:
if args.save_biggest is not None:
i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape)
print("Saving biggest bundle: {} streamlines.".format(matrix[i, j]))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, start_blocs, end_blocs, i, j)
sft = in_sft.from_sft(biggest, in_sft)
save_tractogram(sft, args.save_biggest)

if args.save_smallest is not None:
tmp_matrix = np.ma.masked_equal(matrix, 0)
i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape)
print("Saving smallest bundle: {} streamlines.".format(matrix[i, j]))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, start_blocs, end_blocs, i, j)
sft = in_sft.from_sft(biggest, in_sft)
save_tractogram(sft, args.save_smallest)

if args.show_now:
plt.show()

Expand Down
Loading

0 comments on commit 3ec6d1e

Please sign in to comment.