diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index 5249b80a8..6e244af92 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -408,11 +408,8 @@ def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf, valid_length_ids = np.logical_and(lengths >= min_length, lengths <= max_length) filtered_sft = sft[valid_length_ids] - - if return_rejected: - rejected_sft = sft[~valid_length_ids] else: - valid_length_ids = [] + valid_length_ids = np.array([], dtype=bool) filtered_sft = sft # Return to original space @@ -420,7 +417,7 @@ def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf, filtered_sft.to_space(orig_space) if return_rejected: - rejected_sft.to_space(orig_space) + rejected_sft = sft[~valid_length_ids] return filtered_sft, valid_length_ids, rejected_sft else: return filtered_sft, valid_length_ids diff --git a/scilpy/tractograms/tests/test_streamline_operations.py b/scilpy/tractograms/tests/test_streamline_operations.py index ccddcc972..d381da976 100644 --- a/scilpy/tractograms/tests/test_streamline_operations.py +++ b/scilpy/tractograms/tests/test_streamline_operations.py @@ -7,6 +7,7 @@ import pytest from dipy.io.streamline import load_tractogram from dipy.tracking.streamlinespeed import length +from dipy.io.stateful_tractogram import StatefulTractogram from scilpy import SCILPY_HOME from scilpy.io.fetcher import fetch_data, get_testing_files_dict @@ -174,6 +175,17 @@ def test_filter_streamlines_by_length(): # Test that streamlines shorter than 100 and longer than 120 were removed. assert np.all(lengths >= min_length) and np.all(lengths <= max_length) + # === 4. Return rejected streamlines with empty sft === + empty_sft = short_sft[[]] # Empty sft from short_sft (chosen arbitrarily) + filtered_sft, _, rejected = \ + filter_streamlines_by_length(empty_sft, min_length=min_length, + max_length=max_length, + return_rejected=True) + assert isinstance(filtered_sft, StatefulTractogram) + assert isinstance(rejected, StatefulTractogram) + assert len(filtered_sft) == 0 + assert len(rejected) == 0 + def test_filter_streamlines_by_total_length_per_dim(): long_sft = load_tractogram(in_long_sft, in_ref)