From 65a3e2e2b9715715956022368061375284546d5b Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Wed, 17 Apr 2024 15:42:21 +0200 Subject: [PATCH] Add check for overlapping points in neurites (#1115) --- neurom/check/morphology_checks.py | 17 ++++++- neurom/check/morphtree.py | 72 +++++++++++++++++++++++++++ tests/check/test_morphology_checks.py | 30 +++++++++++ tests/check/test_morphtree.py | 7 +++ 4 files changed, 125 insertions(+), 1 deletion(-) diff --git a/neurom/check/morphology_checks.py b/neurom/check/morphology_checks.py index a5a6011f..ef2fc568 100644 --- a/neurom/check/morphology_checks.py +++ b/neurom/check/morphology_checks.py @@ -35,7 +35,7 @@ import numpy as np from neurom import NeuriteType from neurom.check import CheckResult -from neurom.check.morphtree import get_flat_neurites, back_tracking_segments +from neurom.check.morphtree import get_flat_neurites, back_tracking_segments, overlapping_points from neurom.core.morphology import Section, iter_neurites, iter_sections, iter_segments from neurom.core.dataformat import COLS from neurom.exceptions import NeuroMError @@ -365,3 +365,18 @@ def has_no_back_tracking(morph): for i in back_tracking_segments(neurite) ] return CheckResult(len(bad_ids) == 0, bad_ids) + + +def has_no_overlapping_point(morph, tolerance=None): + """Check if the morphology has overlapping points. + + Returns: + CheckResult with result. `result.info` contains a tuple with the two overlapping section ids + and a list containing only the first overlapping points. + """ + bad_ids = [ + (i[:2], np.atleast_2d(i[2])) + for neurite in iter_neurites(morph) + for i in overlapping_points(neurite, tolerance=tolerance) + ] + return CheckResult(len(bad_ids) == 0, bad_ids) diff --git a/neurom/check/morphtree.py b/neurom/check/morphtree.py index 293293e1..1f0abf17 100644 --- a/neurom/check/morphtree.py +++ b/neurom/check/morphtree.py @@ -29,6 +29,8 @@ """Python module of NeuroM to check morphology trees.""" import numpy as np +from scipy.spatial import KDTree + from neurom.core.dataformat import COLS from neurom import morphmath as mm from neurom.morphmath import principal_direction_extent @@ -204,6 +206,64 @@ def is_back_tracking(neurite): return False +def overlapping_points(neurite, tolerance=None): + """Return overlapping points of a neurite. + + Args: + neurite(Neurite): neurite to operate on + tolerance(float): the tolerance used to find overlapping points + + Returns: + A generator of tuples containing the IDs of the two intersecting sections and the + overlapping point. + """ + # Create an array containing all the points of the neurite with 1st and last points of each + # section deduplicated. This array has 4 columns: the section ID of the point and the + # XYZ coordinates. + # Note: The section ID is cast to float in this operation and cast back to int later. + section_pts = np.vstack( + [ + np.insert(neurite.root_node.points[0, :3], 0, neurite.root_node.id), + np.vstack( + [ + np.concatenate( + [np.ones((len(sec.points) - 1, 1)) * sec.id, sec.points[1:, :3]], + axis=1, + ) + for sec in neurite.iter_sections() + ], + ), + ], + ) + tree = KDTree(section_pts[:, 1:4]) + if tolerance is None: + tolerance = 0 + for pt_id1, pt_id2 in tree.query_pairs(tolerance): + yield ( + int(section_pts[pt_id1, 0]), # Cast the first section ID back to int + int(section_pts[pt_id2, 0]), # Cast the second section ID back to int + section_pts[pt_id1, 1:4], # The overlapping point of the first section + ) + + +def has_overlapping_points(neurite, tolerance=None): + """Check if a neurite has at least one overlapping point. + + See overlapping_points() for more details. + + Args: + neurite(Neurite): neurite to operate on + tolerance(float): the tolerance used to find overlapping points + + Returns: + True if two points of the neurite are overlapping. + """ + for _i in overlapping_points(neurite, tolerance=tolerance): + # If one overlapping point is found then the neurite is overlapping + return True + return False + + def get_flat_neurites(morph, tol=0.1, method='ratio'): """Check if a morphology has neurites that are flat within a tolerance. @@ -246,3 +306,15 @@ def get_back_tracking_neurites(morph): List of morphologies with backtracks """ return [n for n in morph.neurites if is_back_tracking(n)] + + +def get_overlapping_point_neurites(morph, tolerance=0): + """Get neurites that have overlapping points. + + Args: + morph(Morphology): neurite to operate on + + Returns: + List of morphologies with backtracks + """ + return [n for n in morph.neurites if has_overlapping_points(n, tolerance=tolerance)] diff --git a/tests/check/test_morphology_checks.py b/tests/check/test_morphology_checks.py index ad8defe2..a3d252a2 100644 --- a/tests/check/test_morphology_checks.py +++ b/tests/check/test_morphology_checks.py @@ -488,3 +488,33 @@ def test_has_no_back_tracking(): assert_array_equal(info[0][1], [[1, -3, 0]]) assert_array_equal(info[1][0], [2, 1, 1]) assert_array_equal(info[1][1], [[1, -3, 0]]) + + +def test_has_no_overlapping_point(): + m = load_morphology(""" + ((CellBody) (-1 0 0 2) (1 0 0 2)) + + ((Dendrite) + (0 0 0 0.4) + (0 1 0 0.3) + (0 2 0 0.28) + ( + (0 2 0 0.28) + (1 3 0 0.3) + (2 4 0 0.22) + | + (0 2 0 0.28) + (1 -3 0 0.3) + (2 -4 0 0.24) + (1 -3 0 0.52) + (0 1 0 0.2) + (4 -6 0 0.2) + )) +""", "asc") + result = morphology_checks.has_no_overlapping_point(m) + assert result.status is False + info = result.info + assert_array_equal(info[0][0], [0, 2]) + assert_array_equal(info[0][1], [[0, 1, 0]]) + assert_array_equal(info[1][0], [2, 2]) + assert_array_equal(info[1][1], [[1, -3, 0]]) diff --git a/tests/check/test_morphtree.py b/tests/check/test_morphtree.py index a851987a..a3e1ce2b 100644 --- a/tests/check/test_morphtree.py +++ b/tests/check/test_morphtree.py @@ -197,3 +197,10 @@ def test_get_nonmonotonic_neurites(): def test_get_back_tracking_neurites(): m = load_morphology(Path(SWC_PATH, 'Neuron.swc')) assert len(mt.get_back_tracking_neurites(m)) == 4 + + +def test_get_overlapping_point_neurites(): + m = load_morphology(Path(SWC_PATH, 'Neuron.swc')) + assert len(mt.get_overlapping_point_neurites(m)) == 0 + assert len(mt.get_overlapping_point_neurites(m, tolerance=0.09)) == 1 + assert len(mt.get_overlapping_point_neurites(m, tolerance=999)) == 4