Skip to content

Commit

Permalink
add path pruning, __array__ method, and longest shortest branch detec…
Browse files Browse the repository at this point in the history
…tion to Skeleton (#117)
  • Loading branch information
kevinyamauchi authored Aug 3, 2021
1 parent 97a217d commit ac04540
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 2 deletions.
44 changes: 44 additions & 0 deletions skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from scipy import sparse, ndimage as ndi
from scipy.sparse import csgraph
from scipy import spatial
from skimage import morphology
import numba

from .nputil import raveled_steps_to_neighbors
from .summary_utils import find_main_branches


## NBGraph and Numba-based implementation
Expand Down Expand Up @@ -298,7 +300,9 @@ class Skeleton:
faster graph methods. For example, it is much faster to get a list of
neighbors, or test for the presence of a specific edge.
coordinates : array, shape (N, ndim)
skeleton_pixel_id i -> coordinates[i]
The image coordinates of each pixel in the skeleton.
Some values in this array are nonsensical — you should only access them using node ids.
paths : scipy.sparse.csr_matrix, shape (P, N + 1)
A csr_matrix where element [i, j] is on if node j is in path i. This
includes path endpoints. The number of nonzero elements is N - J + Sd.
Expand Down Expand Up @@ -341,6 +345,7 @@ def __init__(self, skeleton_image, *, spacing=1, source_image=None,
self.degrees = np.diff(self.graph.indptr)
self.spacing = (np.asarray(spacing) if not np.isscalar(spacing)
else np.full(skeleton_image.ndim, spacing))
self.unique_junctions = unique_junctions
if keep_images:
self.skeleton_image = skeleton_image
self.source_image = source_image
Expand Down Expand Up @@ -429,6 +434,22 @@ def paths_list(self):
"""
return [list(self.path(i)) for i in range(self.n_paths)]

def path_label_image(self):
    """Paint every path of the skeleton into an integer label image.

    Paths are painted in increasing order of path id, so a pixel that is
    visited by more than one path ends up with the label of the last path
    that visits it.

    Returns
    -------
    label_image : array of ints
        Array with the same shape as ``self.skeleton_image``. Pixels on
        path ``i`` hold the value ``i + 1``; pixels on no path hold 0.
    """
    labels = np.zeros(self.skeleton_image.shape, dtype=int)
    for label, path_id in enumerate(range(self.n_paths), start=1):
        # Coordinates are rounded to the nearest integer so that they can
        # be used as array indices.
        pixel_coords = np.round(self.path_coordinates(path_id)).astype(int)
        # One index array per image axis -> NumPy "fancy" indexing.
        labels[tuple(pixel_coords.T)] = label
    return labels

def path_means(self):
"""Compute the mean pixel value along each path.
Expand All @@ -455,6 +476,27 @@ def path_stdev(self):
means = self.path_means()
return np.sqrt(np.clip(sumsq/lengths - means*means, 0, None))

def prune_paths(self, indices) -> 'Skeleton':
    """Build a new Skeleton from this one with some paths erased.

    The pixels of every requested path are set to 0 in a copy of
    ``self.skeleton_image``; the result is skeletonized again and passed
    to a new ``Skeleton`` that reuses this object's ``spacing``,
    ``source_image`` and ``unique_junctions`` settings.

    Parameters
    ----------
    indices : iterable of int
        Ids of the paths to remove.

    Returns
    -------
    pruned : Skeleton
        A new Skeleton computed from the pruned image. ``self`` is left
        untouched.

    Notes
    -----
    NOTE(review): every coordinate returned by ``path_coordinates`` is
    erased; if that includes the path endpoints, junction pixels shared
    with other branches are removed as well — confirm this is intended.
    """
    # warning: slow
    pruned_image = np.copy(self.skeleton_image)
    for path_id in indices:
        # Round the path coordinates so they can serve as array indices.
        pixel_coords = np.round(self.path_coordinates(path_id)).astype(int)
        pruned_image[tuple(pixel_coords.T)] = 0
    # optional cleanup: thin whatever is left, then multiply by the pruned
    # image so that surviving pixels keep their original values.
    thinned = morphology.skeletonize(pruned_image.astype(bool))
    new_skeleton = thinned * pruned_image
    return Skeleton(
        new_skeleton,
        spacing=self.spacing,
        source_image=self.source_image,
        unique_junctions=self.unique_junctions,
    )

def __array__(self, dtype=None, copy=None):
    """Array representation of the skeleton path labels.

    Implements NumPy's array-coercion protocol, so ``np.asarray(skeleton)``
    yields the image produced by ``self.path_label_image()``.

    Parameters
    ----------
    dtype : data-type, optional
        If given, the label image is converted to this dtype before being
        returned. By default the dtype of the label image is kept.
    copy : bool, optional
        Accepted for compatibility with the ``__array__`` signature used by
        NumPy 2. It has no effect: the label image is computed anew on
        every call, so the returned array never aliases internal state.

    Returns
    -------
    label_image : array
        The path label image, in the requested dtype.
    """
    labels = self.path_label_image()
    if dtype is None:
        return labels
    # copy=False avoids a second allocation when the dtype already matches.
    return labels.astype(dtype, copy=False)


def summarize(skel: Skeleton):
"""Compute statistics for every skeleton and branch in ``skel``.
Expand Down Expand Up @@ -503,6 +545,8 @@ def summarize(skel: Skeleton):
np.sqrt((coords_real_dst - coords_real_src)**2 @ np.ones(ndim))
)
df = pd.DataFrame(summary)
# define main branch as longest shortest path within a single skeleton
df['main'] = find_main_branches(df)
return df


Expand Down
3 changes: 1 addition & 2 deletions skan/nputil.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import collections
from collections import abc
import itertools
import numpy as np

Expand Down Expand Up @@ -67,7 +67,6 @@ def smallest_int_dtype(number, *, signed=False, min_dtype=np.int8):
dtype = min_dtype
return dtype


def raveled_steps_to_neighbors(shape, connectivity=1, *, order='C', spacing=1,
return_distances=True):
"""Return raveled coordinate steps for given array shape and neighborhood.
Expand Down
63 changes: 63 additions & 0 deletions skan/summary_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import networkx as nx
import numpy as np
from pandas import DataFrame
import toolz as tz


def find_main_branches(summary: DataFrame) -> np.ndarray:
    """Flag the branches on the longest shortest path of each skeleton.

    The summary table is converted to an undirected weighted graph with one
    edge per pair of connected nodes, weighted by ``branch-distance``. For
    each connected component, the pair of nodes separated by the greatest
    shortest-path distance is located, and every branch along the shortest
    path joining that pair is marked as "main".

    Parameters
    ----------
    summary : pd.DataFrame
        The summary table of the skeleton to analyze.
        This must contain: ['node-id-src', 'node-id-dst', 'branch-distance']

    Returns
    -------
    is_main : array of bool
        One entry per row of ``summary``. True if the index-matched path is
        on the longest shortest path of its skeleton.

    Notes
    -----
    When several branches join the same pair of nodes, only the shortest
    one can be part of a shortest path, so only that branch is eligible to
    be flagged.

    A connected component in which no two nodes are a positive distance
    apart (for example a single node whose only branch loops back onto
    itself) has no longest shortest path; none of its branches is flagged.
    """
    is_main = np.zeros(summary.shape[0], dtype=bool)
    us = summary['node-id-src']
    vs = summary['node-id-dst']
    ws = summary['branch-distance']

    # Build the graph and the edge -> row lookup together so that they
    # always agree on which row represents a given pair of nodes.
    g = nx.Graph()
    edge2idx = {}
    for i, (u, v, w) in enumerate(zip(us, vs, ws)):
        if g.has_edge(u, v) and g[u][v]['weight'] <= w:
            # A branch that is at least as short already joins these nodes.
            continue
        g.add_edge(u, v, weight=w)
        # The graph is undirected: register both orientations.
        edge2idx[(u, v)] = i
        edge2idx[(v, u)] = i

    for conn in nx.connected_components(g):
        curr_val = 0
        curr_pair = None
        h = g.subgraph(conn)
        p = dict(nx.all_pairs_dijkstra_path_length(h))
        for src in p:
            for dst in p[src]:
                val = p[src][dst]
                if (val is not None
                        and np.isfinite(val)
                        and val > curr_val):
                    curr_val = val
                    curr_pair = (src, dst)
        if curr_pair is None:
            # No pair of nodes is a positive distance apart, so there is
            # no path to mark in this component.
            continue
        main_path = nx.shortest_path(
            h, source=curr_pair[0], target=curr_pair[1], weight='weight'
        )
        for i, j in tz.sliding_window(2, main_path):
            is_main[edge2idx[(i, j)]] = True

    return is_main
21 changes: 21 additions & 0 deletions skan/test/test_prune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from skan._testdata import skeleton0
from skan import Skeleton


@pytest.mark.parametrize('branch_num', [0])
def test_pruning(branch_num):
    """Pruning one branch of ``skeleton0`` must leave a single path."""
    original = Skeleton(skeleton0)
    result = original.prune_paths([branch_num])
    # Printed so that the pruned image appears in pytest's captured output
    # when the assertion below fails.
    print(result.skeleton_image.astype(int))
    assert result.n_paths == 1


@pytest.mark.xfail
@pytest.mark.parametrize('branch_num', [0, 1, 2])
def test_pruning_comprehensive(branch_num):
    """Pruning any one branch of ``skeleton0`` should leave a single path.

    Marked as an expected failure: the assertion is not currently expected
    to hold for every branch id in the parametrization.
    """
    original = Skeleton(skeleton0)
    result = original.prune_paths([branch_num])
    # Printed so that the pruned image appears in pytest's captured output.
    print(result.skeleton_image.astype(int))
    assert result.n_paths == 1
43 changes: 43 additions & 0 deletions skan/test/test_skeleton_class.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from time import process_time
import numpy as np
from numpy.testing import assert_equal, assert_allclose
import pytest
from skan.csr import Skeleton, summarize

from skan._testdata import (tinycycle, tinyline, skeleton0, skeleton1,
Expand Down Expand Up @@ -100,3 +101,45 @@ def test_skeleton_summarize():
assert set(summary['skeleton-id']) == {1, 2}
assert (np.all(summary['mean-pixel-value'] < 2)
and np.all(summary['mean-pixel-value'] > 1))


@pytest.mark.xfail
def test_skeleton_label_image_strict():
    """Check that the label image has the same pattern as the expected one.

    Each pixel of the label image is paired with the pixel at the same
    position in the expected label image. If both images partition the
    pixels identically, there are exactly as many distinct pairs as there
    are distinct labels in the expected image. This checks that the
    branches are displayed correctly without asserting any particular
    numbering of the branches.

    This is expected to fail due to the current junction representation.
    See: https://github.com/jni/skan/issues/133
    """
    skel = Skeleton(skeleton4, unique_junctions=False)
    labels = np.asarray(skel)
    expected = np.array([
        [1, 0, 0, 0, 0],
        [0, 1, 2, 2, 2],
        [0, 3, 0, 0, 0],
        [0, 3, 0, 0, 0],
    ])
    # One (expected, actual) row per pixel.
    pairs = np.stack((expected, labels), axis=-1).reshape((labels.size, 2))
    n_unique_pairs = len(np.unique(pairs, axis=0))
    n_expected_labels = len(np.unique(expected))
    assert n_expected_labels == n_unique_pairs


def test_skeleton_label_image():
    """Check that the label image marks the same pixels as expected.

    Only the foreground/background pattern is compared; the label values
    themselves are not.
    """
    skel = Skeleton(skeleton4, unique_junctions=False)
    labels = np.asarray(skel)
    expected = np.array([
        [1, 0, 0, 0, 0],
        [0, 1, 2, 2, 2],
        [0, 3, 0, 0, 0],
        [0, 3, 0, 0, 0],
    ])

    actual_mask = labels.astype(bool)
    expected_mask = expected.astype(bool)
    np.testing.assert_array_equal(actual_mask, expected_mask)
19 changes: 19 additions & 0 deletions skan/test/test_summary_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

from skan._testdata import skeleton1
from skan import Skeleton, summarize

def test_find_main():
    """Exactly one branch of ``skeleton1`` is off the main path.

    That branch must join the pixels (2, 1) and (3, 3), in either
    direction.
    """
    skel = Skeleton(skeleton1)
    summary = summarize(skel)

    side_branch_start = [2, 1]
    side_branch_end = [3, 3]

    # 'main' is a boolean column, so ``~`` selects the non-main branches.
    side_branches = summary.loc[~summary['main']]
    assert side_branches.shape[0] == 1

    coord_columns = ['coord-src-0', 'coord-src-1', 'coord-dst-0', 'coord-dst-1']
    coords = side_branches[coord_columns].to_numpy()
    # List concatenation: [src-0, src-1, dst-0, dst-1] in each orientation.
    forward = side_branch_start + side_branch_end
    backward = side_branch_end + side_branch_start
    assert np.all(coords == forward) or np.all(coords == backward)

0 comments on commit ac04540

Please sign in to comment.