Skip to content

Commit

Permalink
Initial work on networkx-based pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
jni committed Oct 16, 2023
1 parent 482d2d1 commit eac5da3
Showing 1 changed file with 142 additions and 96 deletions.
238 changes: 142 additions & 96 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy.typing as npt
import numba
import warnings
from typing import Tuple
from typing import Tuple, Callable
from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches

Expand Down Expand Up @@ -1220,15 +1220,60 @@ def skeleton_to_nx(skeleton: Skeleton, summary: pd.DataFrame | None = None):
return g


def _merge_paths(p1: npt.NDArray, p2: npt.NDArray):
    """Concatenate two paths that meet at a common endpoint.

    The last point of `p1` is dropped before `p2` is appended, so the point
    where the two paths meet (expected to be the first point of `p2`) appears
    only once in the result.
    """
    # Everything in p1 except its final point; p2 supplies the junction.
    head = p1[:-1]
    pieces = (head, p2)
    return np.concatenate(pieces, axis=0)


def _edge_path_from(data: dict, edge: tuple[int, int], start: int):
    """Return the path stored on an edge, oriented to begin at node `start`.

    The stored 'node_id_src'/'node_id_dst' attributes are used to decide the
    direction of the stored path when they match the edge's endpoints;
    otherwise the path is assumed to run from ``edge[0]`` to ``edge[1]``.
    """
    src = data.get('node_id_src')
    dst = data.get('node_id_dst')
    if {src, dst} != set(edge):
        # No usable stored endpoints: fall back to the order of the tuple.
        src = edge[0]
    return data['path'] if src == start else data['path'][::-1]


def _merge_edges(g: nx.Graph, e1: tuple[int], e2: tuple[int]):
    """Replace two edges that share exactly one node by a single edge.

    The shared ("middle") node is removed from `g`, and an edge joining the
    two outer nodes is added whose attributes combine those of `e1` and `e2`.

    Parameters
    ----------
    g : nx.Graph
        Graph to modify in place. Both edges must carry the attributes
        'skeleton_id', 'branch_distance', 'branch_type', 'mean_pixel_value',
        'stdev_pixel_value' and 'path'.
    e1, e2 : tuple of int
        The two edges to merge, each given as a pair of node IDs. The merged
        path runs from the outer node of `e1` to the outer node of `e2`.

    Raises
    ------
    ValueError
        If the edges do not share exactly one node, or if either edge is a
        self-loop.
    """
    shared = set(e1) & set(e2)
    if len(shared) != 1:
        raise ValueError(
            f'edges {e1} and {e2} must share exactly one node, '
            f'found {len(shared)}'
        )
    # Unpack the single shared node; using the set itself as a node ID would
    # be unhashable and would never compare equal to an integer node.
    (middle_node,) = shared
    outer1 = set(e1) - shared
    outer2 = set(e2) - shared
    if len(outer1) != 1 or len(outer2) != 1:
        raise ValueError(f'cannot merge self-loop edges: {e1}, {e2}')
    (start,) = outer1
    (end,) = outer2

    d1 = g.edges[e1]
    d2 = g.edges[e2]
    # Orient both paths so that they read start -> middle -> end.
    p1 = _edge_path_from(d1, e1, start)
    p2 = _edge_path_from(d2, e2, middle_node)
    n1 = len(d1['path'])
    n2 = len(d2['path'])
    # NOTE: the junction point belongs to both input paths, so it is counted
    # in both n1 and n2 when weighting the statistics below.
    new_edge_values = {
        'skeleton_id': d1['skeleton_id'],
        'node_id_src': start,
        'node_id_dst': end,
        'branch_distance': d1['branch_distance'] + d2['branch_distance'],
        'branch_type': min(d1['branch_type'], d2['branch_type']),
        'mean_pixel_value': (
            n1 * d1['mean_pixel_value'] + n2 * d2['mean_pixel_value']
        ) / (n1 + n2),
        'stdev_pixel_value': np.sqrt((
            d1['stdev_pixel_value']**2 * (n1 - 1)
            + d2['stdev_pixel_value']**2 * (n2 - 1)
        ) / (n1 + n2 - 1)),
        'path': _merge_paths(p1, p2),
    }
    g.add_edge(start, end, **new_edge_values)
    # Removing the middle node also removes the two original edges.
    g.remove_node(middle_node)


def _remove_simple_path_nodes(g):
    """Remove nodes of degree 2 by merging their two incident edges.

    Nodes whose two incident edges cannot be merged into a single new edge
    are left untouched: nodes that do not have exactly two distinct
    neighbours, and nodes whose two neighbours are already joined by an edge.

    Parameters
    ----------
    g : nx.Graph
        Graph to simplify; it is modified in place.
    """
    # Snapshot the candidates first: merging removes nodes, so the node view
    # cannot be iterated while the graph is being edited.
    to_remove = [n for n in g.nodes if g.degree(n) == 2]
    for u in to_remove:
        neighbors = list(g[u])
        # Degree 2 does not guarantee two neighbours (e.g. a self-loop,
        # which networkx counts twice towards the degree), and the snapshot
        # above may be stale by the time this node is reached.
        if len(neighbors) != 2:
            continue
        v, w = neighbors
        # In a simple graph, adding the merged v-w edge would overwrite an
        # existing v-w edge and lose its data, so keep such nodes as they are.
        if g.has_edge(v, w):
            continue
        _merge_edges(g, (u, v), (u, w))


def iteratively_prune_paths(
    skeleton: nx.Graph,
    discard: Callable[[nx.Graph, dict], bool],
) -> nx.Graph:
    """Iteratively prune edges of a skeleton graph using a predicate.

    On every pass, each edge's attribute dictionary is handed to `discard`
    together with the graph. The edges for which it returns True are removed,
    and nodes of degree 2 are then removed by merging their incident edges.
    Passes are repeated until one of them discards no edge.

    Parameters
    ----------
    skeleton : nx.Graph
        Graph representation of the skeleton to be pruned. It is modified in
        place.
    discard : Callable[[nx.Graph, dict], bool]
        A predicate that is True if the edge should be discarded. It is
        called with the current graph and a dictionary of all the attributes
        of that edge — the same as the columns in the output of `summarize`.

    Returns
    -------
    nx.Graph
        The same graph object as `skeleton`, with the given edges pruned and
        remaining paths merged.
    """
    pruned = skeleton  # same object: pruning mutates the caller's graph

    while True:
        # Collect first and remove afterwards, so that the edge view is not
        # modified while it is being iterated.
        for_pruning = [
            (u, v)
            for u, v, attrs in pruned.edges(data=True)
            if discard(pruned, attrs)
        ]
        pruned.remove_edges_from(for_pruning)
        # This also runs on the final pass, so paths are merged even when
        # no edge was discarded at all.
        _remove_simple_path_nodes(pruned)
        if not for_pruning:
            break
    return pruned


# TODO: the legacy pruning loop below is kept for reference only; it needs to be turned into a `discard` predicate callback.
# while branch_data.shape[0] > min_skeleton:
# # Remove branches that have endpoint (branch_type == 1)
# n_paths = branch_data.shape[0]
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=1,
# find_main_branch=find_main_branch,
# **kwargs
# )
# # Check to see if we have a looped path with branches; if so, and the branch is shorter than the loop, we
# # remove it and break. We can either look for whether there are just two branches
# # if branch_data.shape[0] == 2:
# # length_branch_type1 = branch_data.loc[branch_data["branch-type"] ==
# # 1,
# # "branch-distance"].values[0]
# # length_branch_type3 = branch_data.loc[branch_data["branch-type"] ==
# # 3,
# # "branch-distance"].values[0]
# # if length_branch_type3 > length_branch_type1:
# # pruned, branch_data = _remove_branch_type(
# # pruned, branch_data, branch_type=1, find_main_branch=find_main_branch, **kwargs
# # )
# # ...or perhaps more generally whether we have just one loop left and if its length is less than other branches
# if branch_data.loc[branch_data["branch-type"] == 3].shape[0] == 1:
# # Extract the length of a loop
# length_branch_type3 = branch_data.loc[branch_data["branch-type"] ==
# 3,
# "branch-distance"].values[0]
# # Extract indices for branches lengths less than this and prune them
# pruned = pruned.prune_paths(
# branch_data.loc[branch_data["branch-distance"] <
# length_branch_type3].index
# )
# branch_data = summarize(pruned, find_main_branch=find_main_branch)
#
# # We now need to check whether we have the desired number of branches (often 1), have to check before removing
# # branches of type 3 in case this is the final, clean, loop.
# if branch_data.shape[0] == min_skeleton:
# break
# # If not we need to remove any small side loops (branch_type == 3)
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=3,
# find_main_branch=find_main_branch,
# **kwargs
# )
# # We don't need to check if we have a single path as that is the control check for the while loop, however we do
# # need to check if we are removing anything as some skeletons of closed loops have internal branches that won't
# # ever get pruned. This happens when there are internal loops to the main one so we never observe a loop with a
# # single branch. The remaining branches ARE part of the main branch which is why they haven't (yet) been
# # removed. We now prune those and check whether we have reduced the number of paths, if not we're done pruning.
# if branch_data.shape[0] == n_paths:
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=1,
# find_main_branch=False,
# **kwargs
# )
# # If this HASN'T removed any more branches we are done
# if branch_data.shape[0] == n_paths:
# break
# return pruned


def _remove_branch_type(
skeleton: Skeleton, branch_data: pd.DataFrame, branch_type: int,
find_main_branch: bool, **kwargs
Expand Down

0 comments on commit eac5da3

Please sign in to comment.