Skip to content

Commit

Permalink
Connectivity labels: save rejected streamlines
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Feb 16, 2024
1 parent 493299e commit 899454f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 34 deletions.
44 changes: 27 additions & 17 deletions dwi_ml/data/processing/streamlines/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ def _compute_origin_finish_blocs(streamlines, volume_size, nb_blocs):


def compute_triu_connectivity_from_labels(streamlines, data_labels,
binary: bool = False,
use_scilpy=False):
"""
Compute a connectivity matrix.
Expand All @@ -313,8 +312,6 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,
Streamlines, in vox space, corner origin.
data_labels: np.ndarray
The loaded nifti image.
binary: bool
If True, return a binary matrix.
use_scilpy: bool
If True, uses scilpy's method:
'Strategy is to keep the longest streamline segment
Expand Down Expand Up @@ -380,16 +377,14 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,

start_labels.append(start)
end_labels.append(end)

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1

matrix = np.triu(matrix)
assert matrix.sum() == len(streamlines)

if binary:
matrix = matrix.astype(bool)

return matrix, real_labels, start_labels, end_labels


Expand Down Expand Up @@ -463,9 +458,11 @@ def prepare_figure_connectivity(matrix):
axs[1, 1].imshow(matrix)
axs[1, 1].set_title("Binary")

plt.suptitle("All versions of the connectivity matrix.")


def find_streamlines_with_chosen_connectivity(
streamlines, label1, label2, start_labels, end_labels):
streamlines, start_labels, end_labels, label1, label2=None):
"""
Returns streamlines corresponding to a (label1, label2) or (label2, label1)
connection.
Expand All @@ -474,19 +471,32 @@ def find_streamlines_with_chosen_connectivity(
----------
streamlines: list of np arrays or list of tensors.
Streamlines, in vox space, corner origin.
label1: int
The bloc of interest, either as starting or finishing point.
label2: int
The bloc of interest, either as starting or finishing point.
start_labels: list[int]
The starting bloc for each streamline.
end_labels: list[int]
The ending bloc for each streamline.
label1: int
The bloc of interest, either as starting or finishing point.
label2: int, optional
The bloc of interest, either as starting or finishing point.
If label2 is None, then all connections (label1, Y) and (X, label1)
are found.
"""
start_labels = np.asarray(start_labels)
end_labels = np.asarray(end_labels)

str_ind1 = np.logical_and(start_labels == label1,
end_labels == label2)
str_ind2 = np.logical_and(start_labels == label2,
end_labels == label1)
str_ind = np.logical_or(str_ind1, str_ind2)
return [s for i, s in enumerate(streamlines) if str_ind[i]]
if label2 is None:
labels2 = np.unique(np.concatenate((start_labels[:], end_labels[:])))
else:
labels2 = [label2]

found = np.zeros(len(streamlines))
for label2 in labels2:
str_ind1 = np.logical_and(start_labels == label1,
end_labels == label2)
str_ind2 = np.logical_and(start_labels == label2,
end_labels == label1)
str_ind = np.logical_or(str_ind1, str_ind2)
found = np.logical_or(found, str_ind)

return [s for i, s in enumerate(streamlines) if found[i]]
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():
i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape)
print("Saving biggest bundle: {} streamlines.".format(matrix[i, j]))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
in_sft.streamlines, start_blocs, end_blocs, i, j)
sft = in_sft.from_sft(biggest, in_sft)
save_tractogram(sft, args.save_biggest)

Expand All @@ -108,7 +108,7 @@ def main():
i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape)
print("Saving smallest bundle: {} streamlines.".format(matrix[i, j]))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
in_sft.streamlines, start_blocs, end_blocs, i, j)
sft = in_sft.from_sft(biggest, in_sft)
save_tractogram(sft, args.save_smallest)

Expand Down
35 changes: 20 additions & 15 deletions scripts_python/dwiml_compute_connectivity_matrix_from_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _build_arg_parser():
"streamline count is saved.")
p.add_argument('--show_now', action='store_true',
help="If set, shows the matrix with matplotlib.")
p.add_argument('--hide_background', nargs='?', const=0, type=float,
p.add_argument('--hide_background', nargs='?', const=0, type=int,
help="If true, set the connectivity matrix for chosen "
"label (default: 0), to 0.")
p.add_argument(
Expand Down Expand Up @@ -80,10 +80,11 @@ def main():
p.error("--out_file should have a .npy extension.")

out_fig = tmp + '.png'
out_fig_noback = tmp + '_hidden_background.png'
out_ordered_labels = tmp + '_labels.txt'
out_rejected_streamlines = tmp + '_rejected_from_background.trk'
assert_inputs_exist(p, [args.in_labels, args.streamlines])
assert_outputs_exist(p, args, [args.out_file, out_fig, out_fig_noback],
assert_outputs_exist(p, args,
[args.out_file, out_fig, out_rejected_streamlines],
[args.save_biggest, args.save_smallest])

ext = os.path.splitext(args.streamlines)[1]
Expand All @@ -101,32 +102,36 @@ def main():

in_sft.to_vox()
in_sft.to_corner()
matrix, ordered_labels, start_blocs, end_blocs = \
matrix, ordered_labels, start_labels, end_labels = \
compute_triu_connectivity_from_labels(
in_sft.streamlines, data_labels,
use_scilpy=args.use_longest_segment)

prepare_figure_connectivity(matrix)
plt.savefig(out_fig)

if args.hide_background is not None:
idx = ordered_labels.index(args.hide_background)
nb_hidden = np.sum(matrix[idx, :]) + np.sum(matrix[:, idx]) - \
matrix[idx, idx]
if nb_hidden > 0:
logging.info("CAREFUL! Deleting from the matrix {} streamlines "
"with one or both endpoints in a non-labelled area "
"(background = {}; line/column {})"
.format(nb_hidden, args.hide_background, idx))
logging.warning("CAREFUL! Deleting from the matrix {} streamlines "
"with one or both endpoints in a non-labelled "
"area (background = {}; line/column {})"
.format(nb_hidden, args.hide_background, idx))
rejected = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, start_labels, end_labels, idx)
logging.info("Saving rejected streamlines in {}"
.format(out_rejected_streamlines))
sft = in_sft.from_sft(rejected, in_sft)
save_tractogram(sft, out_rejected_streamlines)
else:
logging.info("No streamlines with endpoints in the background :)")
matrix[idx, :] = 0
matrix[:, idx] = 0
ordered_labels[idx] = ("Hidden background ({})"
.format(args.hide_background))

prepare_figure_connectivity(matrix)
plt.savefig(out_fig_noback)
# Save figure will all versions of the matrix.
prepare_figure_connectivity(matrix)
plt.savefig(out_fig)

if args.binary:
matrix = matrix > 0
Expand All @@ -143,7 +148,7 @@ def main():
.format(matrix[i, j], ordered_labels[i], ordered_labels[j],
i, j))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
in_sft.streamlines, i, j, start_labels, end_labels)
sft = in_sft.from_sft(biggest, in_sft)
save_tractogram(sft, args.save_biggest)

Expand All @@ -155,7 +160,7 @@ def main():
.format(matrix[i, j], ordered_labels[i], ordered_labels[j],
i, j))
smallest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
in_sft.streamlines, i, j, start_labels, end_labels)
sft = in_sft.from_sft(smallest, in_sft)
save_tractogram(sft, args.save_smallest)

Expand Down

0 comments on commit 899454f

Please sign in to comment.