Skip to content

Commit

Permalink
ENH: better dps/dpp handling
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineTheb committed Dec 26, 2024
1 parent 4fbfdc9 commit 9d14d3f
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 42 deletions.
58 changes: 40 additions & 18 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np

from collections.abc import Iterable
from nibabel.streamlines import ArraySequence


Expand All @@ -23,18 +24,27 @@ def get_data_as_arraysequence(data, ref_sft):
data_as_arraysequence: ArraySequence
The data as an ArraySequence.
"""

if data.shape[0] == len(ref_sft):
# Check if data has the right shape, either one value per streamline or one
# value per point.
if data.shape[0] == ref_sft._get_streamline_count():
# Two consective if statements to handle both 1D and 2D arrays
# and turn them into lists of lists of lists.
# Check if the data is a vector or a scalar.
if len(data.shape) == 1:
data = data[:, None]
# ArraySequence expects a list of lists of lists, so we need to add
# an extra dimension.
if len(data.shape) == 2:
data = data[:, None, :]
data_as_arraysequence = ArraySequence(data)

elif data.shape[0] == ref_sft._get_point_count():
data_as_arraysequence = ArraySequence()
# This function was created to avoid messing with _data, _offsets and
# _lengths, so this feel kind of bad. However, the other way would be
# to create a new ArraySequence and iterate over the streamlines, but
# that would be way slower.
data_as_arraysequence._data = data
data_as_arraysequence._offsets = ref_sft.streamlines._offsets
data_as_arraysequence._lengths = ref_sft.streamlines._lengths
# Split the data into a list of arrays, one per streamline.
# np.split takes the indices at which to split the array, so use
# np.cumsum to get the indices of the end of each streamline.
data_split = np.split(data, np.cumsum(ref_sft.streamlines._lengths)[:-1])
# Create an ArraySequence from the list of arrays.
data_as_arraysequence = ArraySequence(data_split)
else:
raise ValueError("Data has the wrong shape. Expecting either one value"
" per streamline ({}) or one per point ({}) but got "
Expand Down Expand Up @@ -93,19 +103,31 @@ def add_data_as_color_dpp(sft, color):
The upper bound of the associated colormap.
"""

if color.total_nb_rows == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
if len(color) == sft._get_streamline_count():
if color.common_shape != (3,):
raise ValueError("Colors do not have the right shape. Expecting "
"RBG values, but got values of shape {}.".format(
color.common_shape))

tmp = [np.tile([color[i][0][0], color[i][0][1], color[i][0][2]],
(len(sft.streamlines[i]), 1))
for i in range(len(sft.streamlines))]
sft.data_per_point['color'] = tmp
elif color.total_nb_rows == sft.streamlines.total_nb_rows:

elif len(color) == sft._get_point_count():

if color.common_shape != (3,):
raise ValueError("Colors do not have the right shape. Expecting "
"RBG values, but got values of shape {}.".format(
color.common_shape))

sft.data_per_point['color'] = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. Expecting either one color per streamline "
"({}) or one per point ({}) but got {}."
.format(len(sft), len(sft.streamlines._data),
len(color)))
raise ValueError("Colors do not have the right shape. Expecting either"
" one color per streamline ({}) or one per point ({})"
" but got {}.".format(sft._get_streamline_count(),
sft._get_point_count(),
color.total_nb_rows))
return sft


Expand Down
88 changes: 64 additions & 24 deletions scilpy/tractograms/tests/test_dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-
import nibabel as nib
import numpy as np
import pytest

from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin

from scilpy.image.volume_space_management import DataVolume
from scilpy.tests.utils import nan_array_equal
from scilpy.tractograms.dps_and_dpp_management import (
get_data_as_arraysequence,
add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines,
project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps,
perform_correlation_on_endpoints)
Expand All @@ -27,45 +30,82 @@ def _get_small_sft():
return fake_sft


def test_add_data_as_color_dpp():
lut = get_lookup_table('viridis')
def test_get_data_as_arraysequence_dpp():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])

# Test 1: One value per point.
array_seq = get_data_as_arraysequence(some_data, fake_sft)

# Important. cmap(1) != cmap(1.0)
lowest_color = np.asarray(lut(0.0)[0:3]) * 255
highest_color = np.asarray(lut(1.0)[0:3]) * 255
assert fake_sft._get_point_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_dps():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20])

# Test 1: One value per point.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_dps_2D():
fake_sft = _get_small_sft()

some_data = np.asarray([[2], [20]])

# Test 1: One value per point.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_error():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1])

# Test 1: One value per point.
with pytest.raises(ValueError):
_ = get_data_as_arraysequence(some_data, fake_sft)


def test_add_data_as_dpp_1_per_point():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Not testing the clipping options. Will be tested through viz.utils tests

# Test 1: One value per point.
# Lowest cmap color should be first point of second streamline.
some_data = [[2, 20, 200], [0.1, 0.3, 22, 5]]
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)
values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)

array_seq = get_data_as_arraysequence(color, fake_sft)
colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)
assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 0.1
assert ubound == 200
assert np.array_equal(colored_sft.data_per_point['color'][1][0, :],
lowest_color)
assert np.array_equal(colored_sft.data_per_point['color'][0][2, :],
highest_color)


def test_add_data_as_dpp_1_per_streamline():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test 2: One value per streamline
# Lowest cmap color should be every point in first streamline
some_data = np.asarray([4, 5])
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)
values = np.asarray([4, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)
array_seq = get_data_as_arraysequence(color, fake_sft)

colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)

assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 4
assert ubound == 5
# Lowest cmap color should be first point of second streamline.
# Same value for all points.
colors_first_line = colored_sft.data_per_point['color'][0]
assert np.array_equal(colors_first_line[0, :], lowest_color)
assert np.all(colors_first_line[1:, :] == colors_first_line[0, :])


def test_convert_dps_to_dpp():
Expand Down

0 comments on commit 9d14d3f

Please sign in to comment.