From 9d14d3fda2f8bc0600b19fff18d07878da6ee0e5 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 26 Dec 2024 11:49:50 -0500 Subject: [PATCH] ENH: better dps/dpp handling --- scilpy/tractograms/dps_and_dpp_management.py | 58 ++++++++---- .../tests/test_dps_and_dpp_management.py | 88 ++++++++++++++----- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index 21dfa0823..fda790b35 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np +from collections.abc import Iterable from nibabel.streamlines import ArraySequence @@ -23,18 +24,27 @@ def get_data_as_arraysequence(data, ref_sft): data_as_arraysequence: ArraySequence The data as an ArraySequence. """ - - if data.shape[0] == len(ref_sft): + # Check if data has the right shape, either one value per streamline or one + # value per point. + if data.shape[0] == ref_sft._get_streamline_count(): + # Two consecutive if statements to handle both 1D and 2D arrays + # and turn them into lists of lists of lists. + # Check if the data is a vector or a scalar. + if len(data.shape) == 1: + data = data[:, None] + # ArraySequence expects a list of lists of lists, so we need to add + # an extra dimension. + if len(data.shape) == 2: + data = data[:, None, :] data_as_arraysequence = ArraySequence(data) + elif data.shape[0] == ref_sft._get_point_count(): - data_as_arraysequence = ArraySequence() - # This function was created to avoid messing with _data, _offsets and - # _lengths, so this feel kind of bad. However, the other way would be - # to create a new ArraySequence and iterate over the streamlines, but - # that would be way slower. 
- data_as_arraysequence._data = data - data_as_arraysequence._offsets = ref_sft.streamlines._offsets - data_as_arraysequence._lengths = ref_sft.streamlines._lengths + # Split the data into a list of arrays, one per streamline. + # np.split takes the indices at which to split the array, so use + # np.cumsum to get the indices of the end of each streamline. + data_split = np.split(data, np.cumsum(ref_sft.streamlines._lengths)[:-1]) + # Create an ArraySequence from the list of arrays. + data_as_arraysequence = ArraySequence(data_split) else: raise ValueError("Data has the wrong shape. Expecting either one value" " per streamline ({}) or one per point ({}) but got " @@ -93,19 +103,31 @@ def add_data_as_color_dpp(sft, color): The upper bound of the associated colormap. """ - if color.total_nb_rows == len(sft): - tmp = [np.tile([color[i][0], color[i][1], color[i][2]], + if len(color) == sft._get_streamline_count(): + if color.common_shape != (3,): + raise ValueError("Colors do not have the right shape. Expecting " + "RBG values, but got values of shape {}.".format( + color.common_shape)) + + tmp = [np.tile([color[i][0][0], color[i][0][1], color[i][0][2]], (len(sft.streamlines[i]), 1)) for i in range(len(sft.streamlines))] sft.data_per_point['color'] = tmp - elif color.total_nb_rows == sft.streamlines.total_nb_rows: + + elif len(color) == sft._get_point_count(): + + if color.common_shape != (3,): + raise ValueError("Colors do not have the right shape. Expecting " + "RBG values, but got values of shape {}.".format( + color.common_shape)) + sft.data_per_point['color'] = color else: - raise ValueError("Error in the code... Colors do not have the right " - "shape. Expecting either one color per streamline " - "({}) or one per point ({}) but got {}." - .format(len(sft), len(sft.streamlines._data), - len(color))) + raise ValueError("Colors do not have the right shape. 
Expecting either" + " one color per streamline ({}) or one per point ({})" + " but got {}.".format(sft._get_streamline_count(), + sft._get_point_count(), + color.total_nb_rows)) return sft diff --git a/scilpy/tractograms/tests/test_dps_and_dpp_management.py b/scilpy/tractograms/tests/test_dps_and_dpp_management.py index c32f2eace..4d0ccd410 100644 --- a/scilpy/tractograms/tests/test_dps_and_dpp_management.py +++ b/scilpy/tractograms/tests/test_dps_and_dpp_management.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- import nibabel as nib import numpy as np +import pytest + from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin from scilpy.image.volume_space_management import DataVolume from scilpy.tests.utils import nan_array_equal from scilpy.tractograms.dps_and_dpp_management import ( + get_data_as_arraysequence, add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines, project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps, perform_correlation_on_endpoints) @@ -27,45 +30,82 @@ def _get_small_sft(): return fake_sft -def test_add_data_as_color_dpp(): - lut = get_lookup_table('viridis') +def test_get_data_as_arraysequence_dpp(): + fake_sft = _get_small_sft() + + some_data = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5]) + + # Test 1: One value per point. + array_seq = get_data_as_arraysequence(some_data, fake_sft) - # Important. cmap(1) != cmap(1.0) - lowest_color = np.asarray(lut(0.0)[0:3]) * 255 - highest_color = np.asarray(lut(1.0)[0:3]) * 255 + assert fake_sft._get_point_count() == array_seq.total_nb_rows + +def test_get_data_as_arraysequence_dps(): fake_sft = _get_small_sft() + some_data = np.asarray([2, 20]) + + # One value per streamline. 
+ array_seq = get_data_as_arraysequence(some_data, fake_sft) + assert fake_sft._get_streamline_count() == array_seq.total_nb_rows + + +def test_get_data_as_arraysequence_dps_2D(): + fake_sft = _get_small_sft() + + some_data = np.asarray([[2], [20]]) + + # One value per streamline, given as a 2D array. + array_seq = get_data_as_arraysequence(some_data, fake_sft) + assert fake_sft._get_streamline_count() == array_seq.total_nb_rows + + +def test_get_data_as_arraysequence_error(): + fake_sft = _get_small_sft() + + some_data = np.asarray([2, 20, 200, 0.1]) + + # Four values: matches neither the streamline nor the point count. + with pytest.raises(ValueError): + _ = get_data_as_arraysequence(some_data, fake_sft) + + +def test_add_data_as_dpp_1_per_point(): + + fake_sft = _get_small_sft() + cmap = get_lookup_table('jet') + # Not testing the clipping options. Will be tested through viz.utils tests # Test 1: One value per point. # Lowest cmap color should be first point of second streamline. - some_data = [[2, 20, 200], [0.1, 0.3, 22, 5]] - colored_sft, lbound, ubound = add_data_as_color_dpp( - fake_sft, lut, some_data) + values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5]) + color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8) + + array_seq = get_data_as_arraysequence(color, fake_sft) + colored_sft = add_data_as_color_dpp( + fake_sft, array_seq) assert len(colored_sft.data_per_streamline.keys()) == 0 assert list(colored_sft.data_per_point.keys()) == ['color'] - assert lbound == 0.1 - assert ubound == 200 - assert np.array_equal(colored_sft.data_per_point['color'][1][0, :], - lowest_color) - assert np.array_equal(colored_sft.data_per_point['color'][0][2, :], - highest_color) + + +def test_add_data_as_dpp_1_per_streamline(): + + fake_sft = _get_small_sft() + cmap = get_lookup_table('jet') # Test 2: One value per streamline # Lowest cmap color should be every point in first streamline - some_data = np.asarray([4, 5]) - colored_sft, lbound, ubound = add_data_as_color_dpp( - fake_sft, lut, some_data) + values = 
np.asarray([4, 5]) + color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8) + array_seq = get_data_as_arraysequence(color, fake_sft) + + colored_sft = add_data_as_color_dpp( + fake_sft, array_seq) + assert len(colored_sft.data_per_streamline.keys()) == 0 assert list(colored_sft.data_per_point.keys()) == ['color'] - assert lbound == 4 - assert ubound == 5 - # Lowest cmap color should be first point of second streamline. - # Same value for all points. - colors_first_line = colored_sft.data_per_point['color'][0] - assert np.array_equal(colors_first_line[0, :], lowest_color) - assert np.all(colors_first_line[1:, :] == colors_first_line[0, :]) def test_convert_dps_to_dpp():