From 5de5f20f2558b71df63784add269061035a2e56b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 13 Jun 2023 19:16:50 +0900 Subject: [PATCH 1/5] removed LapTrackMulti for maintainability --- src/laptrack/__init__.py | 5 +- src/laptrack/_tracking.py | 280 +++------------------------------- tests/test_data_conversion.py | 10 +- tests/test_tracking.py | 47 ++---- 4 files changed, 37 insertions(+), 305 deletions(-) diff --git a/src/laptrack/__init__.py b/src/laptrack/__init__.py index b78a2ddb..d51b82a7 100644 --- a/src/laptrack/__init__.py +++ b/src/laptrack/__init__.py @@ -3,14 +3,13 @@ __author__ = """Yohsuke T. Fukai""" __email__ = "ysk@yfukai.net" -from ._tracking import laptrack, LapTrack, LapTrackMulti, LapTrackBase +from ._tracking import laptrack, LapTrack, ParallelBackend from . import data_conversion, scores, metric_utils, datasets __all__ = [ "laptrack", - "LapTrackBase", "LapTrack", - "LapTrackMulti", + "ParallelBackend", "data_conversion", "scores", "metric_utils", diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index a4201856..7ea33018 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -1,11 +1,7 @@ """Main module for tracking.""" import logging import warnings -from abc import ABC -from abc import abstractmethod from enum import Enum -from inspect import Parameter -from inspect import signature from typing import Callable from typing import cast from typing import Dict @@ -29,7 +25,6 @@ import numpy as np import pandas as pd from scipy.spatial.distance import cdist -from scipy.sparse import coo_matrix from pydantic import BaseModel, Field, Extra @@ -274,24 +269,8 @@ def to_candidates(row): return segments_df, dist_matrix, middle_point_candidates -def _remove_no_split_merge_links(track_tree, segment_connected_edges): - for edge in segment_connected_edges: - assert len(edge) == 2 - younger, elder = edge - # if the edge is involved with branching or merging, do not remove the edge - if ( - sum([int(node[0] > younger[0]) for node in track_tree.neighbors(younger)]) - > 1 - ): - continue - if sum([int(node[0] < elder[0]) for node in track_tree.neighbors(elder)]) > 1: - continue - track_tree.remove_edge(*edge) - return track_tree - - -class LapTrackBase(BaseModel, ABC, extra=Extra.forbid): - """Tracking base class for all LAP tracker with parameters.""" +class LapTrack(BaseModel, extra=Extra.forbid): + """Tracking class for LAP tracker with parameters.""" track_dist_metric: Union[str, Callable] = Field( "sqeuclidean", @@ -547,28 +526,6 @@ def _link_gap_split_merge_from_matrix( return track_tree - @abstractmethod - def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): - """ - Perform gap-closing and splitting/merging prediction. - - Parameters - ---------- - coords : Sequence[NumArray] - The list of coordinates of point for each frame. - The array index means (sample, dimension). - track_tree : nx.Graph - the track tree - connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): - the connected edges list - - Returns - ------- - track_tree : nx.Graph - the updated track tree - """ - ... - def predict( self, coords: Sequence[NumArray], @@ -736,11 +693,25 @@ def predict_dataframe( return track_df, split_df, merge_df + def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): + """ + Perform gap-closing and splitting/merging prediction. -class LapTrack(LapTrackBase): - """Two-step tracking, as TrackMate and K. Jaqaman et al., Nat Methods 5, 695 (2008).""" + Parameters + ---------- + coords : Sequence[NumArray] + The list of coordinates of point for each frame. + The array index means (sample, dimension). + track_tree : nx.Graph + the track tree + connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): + the connected edges list - def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): + Returns + ------- + track_tree : nx.Graph + the updated track tree + """ edges = list(split_edges) + list(merge_edges) if ( self.gap_closing_cost_cutoff @@ -799,219 +770,6 @@ def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges) return track_tree -class LapTrackMulti(LapTrackBase): - """Four-step tracking, performing independent "segment_connecting" steps.""" - - segment_connecting_metric: Union[str, Callable] = Field( - "sqeuclidean", - description="The metric for calculating cost to connect segment ends. " - + "See `track_dist_metric`.", - ) - segment_connecting_cost_cutoff: float = Field( - False, - description="The cost cutoff for splitting. " - + "See `gap_closing_cost_cutoff`.", - ) - - remove_no_split_merge_links: bool = Field( - False, - description="If True, remove segment connections if splitting did not happen.", - ) - - def _get_segment_connecting_matrix( - self, segments_df, force_end_nodes=[], force_start_nodes=[] - ): - return _get_segment_end_connecting_matrix( - segments_df, - 1, # only arrow 1-frame difference - self.segment_connecting_metric, - self.segment_connecting_cost_cutoff, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, - ) - - def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): - # "multi-step" type of fitting (Y. T. Fukai (2022)) - - segments_df = _get_segment_df(coords, track_tree) - - edges = list(split_edges) + list(merge_edges) - force_end_nodes = [tuple(map(int, e[0])) for e in edges] - force_start_nodes = [tuple(map(int, e[1])) for e in edges] - - ###### gap closing step ###### - ###### split - merge step 1 ###### - - get_matrix_fns = { - "gap_closing": self._get_gap_closing_matrix, - "segment_connecting": self._get_segment_connecting_matrix, - } - - segment_connected_edges = [] - for mode, get_matrix_fn in get_matrix_fns.items(): - segments_df, gap_closing_dist_matrix = get_matrix_fn( - segments_df, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, - ) - cost_matrix = build_frame_cost_matrix( - gap_closing_dist_matrix, - track_start_cost=self.segment_start_cost, - track_end_cost=self.segment_end_cost, - ) - xs, _ = lap_optimization(cost_matrix) - - nrow = gap_closing_dist_matrix.shape[0] - ncol = gap_closing_dist_matrix.shape[1] - connections = [(i, xs[i]) for i in range(nrow) if xs[i] < ncol] - for connection in connections: - # connection ... connection segments_df.iloc[i] -> segments_df.iloc[xs[i]] - node_from = tuple( - segments_df.loc[connection[0], ["last_frame", "last_index"]] - ) - node_to = tuple( - segments_df.loc[connection[1], ["first_frame", "first_index"]] - ) - track_tree.add_edge(node_from, node_to) - if mode == "segment_connecting": - segment_connected_edges.append((node_from, node_to)) - - # regenerate segments after closing gaps - segments_df = _get_segment_df(coords, track_tree) - - ###### split - merge step 2 ###### - middle_points: Dict = {} - dist_matrices: Dict = {} - for prefix, cutoff, dist_metric in zip( - ["first", "last"], - [self.splitting_cost_cutoff, self.merging_cost_cutoff], - [self.splitting_dist_metric, self.merging_dist_metric], - ): - dist_metric_argnums = None - if callable(dist_metric): - try: - s = signature(dist_metric) - dist_metric_argnums = len( - [ - 0 - for p in s.parameters.values() - if p.kind == Parameter.POSITIONAL_OR_KEYWORD - or p.kind == Parameter.POSITIONAL_ONLY - ] - ) - except TypeError: - pass - if callable(dist_metric) and dist_metric_argnums >= 3: - logger.info("using callable dist_metric with more than 2 parameters") - # the dist_metric function is assumed to take - # (coordinate1, coordinate2, coordinate_sibring, connected by segment_connecting step) - segment_connected_nodes = [ - e[0 if prefix == "first" else 1] for e in segment_connected_edges - ] # find nodes connected by "segment_connect" steps - _coords = [ - [(*c, frame, ind) for ind, c in enumerate(coord_frame)] - for frame, coord_frame in enumerate(coords) - ] - assert np.all(c.shape[1] == _coords[0].shape[1] for c in _coords) - - # _coords ... (coordinate, frame, if connected by segment_connecting step) - def _dist_metric(c1, c2): - *_c1, frame1, ind1 = c1 - *_c2, frame2, ind2 = c2 - # for splitting case, check the yonger one - if not frame1 < frame2: - # swap frame1 and 2; always assume coordinate 1 is first - tmp = _c1, frame1, ind1 - _c1, frame1, ind1 = _c2, frame2, ind2 - _c2, frame2, ind2 = tmp - check_node = (frame1, ind1) if prefix == "first" else (frame2, ind2) - if dist_metric_argnums == 3: - return dist_metric( - np.array(_c1), - np.array(_c2), - check_node in segment_connected_nodes, - ) - else: - if prefix == "first": - # splitting sibring candidate - candidates = [ - (frame, ind) - for (frame, ind) in track_tree.neighbors((frame1, ind1)) - if frame > frame1 - ] - else: - # merging sibring candidate - candidates = [ - (frame, ind) - for (frame, ind) in track_tree.neighbors((frame2, ind2)) - if frame < frame2 - ] - - if len(candidates) == 0: - c_sib = None - else: - assert len(candidates) == 1 - c_sib = candidates[0] - return dist_metric( - np.array(_c1), - np.array(_c2), - np.array(coords[c_sib[0]][c_sib[1]]) if c_sib else None, - check_node in segment_connected_nodes, - ) - - segments_df[f"{prefix}_frame_coords"] = segments_df.apply( - lambda row: ( - *row[f"{prefix}_frame_coords"], - int(row[f"{prefix}_frame"]), - int(row[f"{prefix}_index"]), - ), - axis=1, - ) - - else: - logger.info("using callable dist_metric with 2 parameters") - _coords = coords - _dist_metric = dist_metric - - ( - segments_df, - dist_matrices[prefix], - middle_points[prefix], - ) = _get_splitting_merging_candidates( - segments_df, - _coords, - cutoff, - prefix, - _dist_metric, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, - ) - - splitting_dist_matrix = dist_matrices["first"] - merging_dist_matrix = dist_matrices["last"] - splitting_all_candidates = middle_points["first"] - merging_all_candidates = middle_points["last"] - N_segments = len(segments_df) - track_tree = self._link_gap_split_merge_from_matrix( - segments_df, - track_tree, - coo_matrix((N_segments, N_segments), dtype=np.float32), # no gap closing - splitting_dist_matrix, - merging_dist_matrix, - splitting_all_candidates, - merging_all_candidates, - ) - - ###### remove segment connections if not associated with split / merge ###### - - if self.remove_no_split_merge_links: - track_tree = _remove_no_split_merge_links( - track_tree.copy(), segment_connected_edges - ) - track_tree.add_edges_from(edges) - return track_tree - - def laptrack(coords: Sequence[NumArray], **kwargs) -> nx.Graph: """ Shorthand for calling `LapTrack.fit(coords)`. diff --git a/tests/test_data_conversion.py b/tests/test_data_conversion.py index f7000c67..3807ce3e 100644 --- a/tests/test_data_conversion.py +++ b/tests/test_data_conversion.py @@ -5,7 +5,6 @@ from laptrack import data_conversion from laptrack import LapTrack -from laptrack import LapTrackMulti def test_convert_dataframe_to_coords(): @@ -189,7 +188,12 @@ def test_convert_tree_to_dataframe(test_trees): ) -@pytest.mark.parametrize("track_class", [LapTrack, LapTrackMulti]) +@pytest.mark.parametrize( + "track_class", + [ + LapTrack, + ], +) def test_convert_tree_to_dataframe_frame_index(track_class): df = pd.DataFrame( { @@ -213,7 +217,7 @@ def test_convert_tree_to_dataframe_frame_index(track_class): assert len(np.unique(df["tree_id"])) > 1 -@pytest.mark.parametrize("track_class", [LapTrack, LapTrackMulti]) +@pytest.mark.parametrize("track_class", [LapTrack]) def test_integration(track_class): df = pd.DataFrame( { diff --git a/tests/test_tracking.py b/tests/test_tracking.py index d40120d2..f655bf2e 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -12,7 +12,6 @@ from laptrack import LapTrack from laptrack import laptrack -from laptrack import LapTrackMulti from laptrack.data_conversion import convert_tree_to_dataframe warnings.simplefilter("ignore", FutureWarning) @@ -103,12 +102,11 @@ def testdata(request, shared_datadir: str): return params, coords, edges_set -@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) @pytest.mark.parametrize("parallel_backend", ["serial", "ray"]) -def test_reproducing_trackmate(testdata, tracker_class, parallel_backend) -> None: +def test_reproducing_trackmate(testdata, parallel_backend) -> None: params, coords, edges_set = testdata params["parallel_backend"] = parallel_backend - lt = tracker_class(**params) + lt = LapTrack(**params) track_tree = lt.predict(coords) assert edges_set == set(track_tree.edges) for n in track_tree.nodes(): @@ -128,7 +126,9 @@ def test_reproducing_trackmate(testdata, tracker_class, parallel_backend) -> Non ) ) df = pd.concat(data) - track_df, split_df, merge_df = lt.predict_dataframe(df, ["x", "y"]) + track_df, split_df, merge_df = lt.predict_dataframe( + df, ["x", "y"], only_coordinate_cols=True + ) assert not any(split_df.duplicated()) assert not any(merge_df.duplicated()) track_df2, split_df2, merge_df2 = convert_tree_to_dataframe(track_tree, coords) @@ -139,7 +139,7 @@ def test_reproducing_trackmate(testdata, tracker_class, parallel_backend) -> Non # check index offset track_df3, split_df3, merge_df3 = lt.predict_dataframe( - df, ["x", "y"], index_offset=2 + df, ["x", "y"], index_offset=2, only_coordinate_cols=True ) assert min(track_df3["track_id"]) == 2 assert min(track_df3["tree_id"]) == 2 @@ -179,35 +179,6 @@ def dist_metric(request): return lambda c1, c2, _1, _2: np.linalg.norm(c1 - c2) ** 2 -def test_multi_algorithm_reproducing_trackmate_lambda(testdata, dist_metric) -> None: - params, coords, edges_set = testdata - params = params.copy() - params.update( - dict( - track_dist_metric=lambda c1, c2: np.linalg.norm(c1 - c2) ** 2, - splitting_dist_metric=dist_metric, - merging_dist_metric=dist_metric, - ) - ) - lt = LapTrackMulti(**params) - track_tree = lt.predict(coords) - assert edges_set == set(track_tree.edges) - - -def test_multi_algorithm_reproducing_trackmate_3_arg_lambda(testdata) -> None: - params, coords, edges_set = testdata - lt = LapTrackMulti(**params) - track_tree = lt.predict(coords) - assert edges_set == set(track_tree.edges) - - -def test_multi_algorithm_reproducing_trackmate_4_arg_lambda(testdata) -> None: - params, coords, edges_set = testdata - lt = LapTrackMulti(**params) - track_tree = lt.predict(coords) - assert edges_set == set(track_tree.edges) - - def test_laptrack_function_shortcut(testdata) -> None: params, coords, edges_set = testdata lt = LapTrack(**params) @@ -270,7 +241,7 @@ def df_to_tuples(df): return tuple([tuple(map(int, v)) for v in df.values]) -@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +@pytest.mark.parametrize("tracker_class", [LapTrack]) def test_connected_edges(tracker_class) -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[10, 10], [13, 11]])] lt = tracker_class( @@ -302,7 +273,7 @@ def test_connected_edges(tracker_class) -> None: ) == {((10, 10), (13, 11)), ((12, 11), (10, 10))} -@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +@pytest.mark.parametrize("tracker_class", [LapTrack]) def test_connected_edges_splitting(tracker_class) -> None: coords = [ np.array([[10, 10], [11, 11], [13, 12]]), @@ -353,7 +324,7 @@ def test_connected_edges_splitting(tracker_class) -> None: # ((10,10),(13,11)),((10,10),(13,15)), is the splitted -@pytest.mark.parametrize("tracker_class", [LapTrack, LapTrackMulti]) +@pytest.mark.parametrize("tracker_class", [LapTrack]) def test_no_connected_node(tracker_class) -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[10, 10], [100, 11]])] lt = tracker_class( From c7792cb8cdcdadc41fe2f97fa2f09b955779523c Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 13 Jun 2023 19:21:08 +0900 Subject: [PATCH 2/5] updated test --- tests/test_tracking_routines.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/test_tracking_routines.py b/tests/test_tracking_routines.py index 9843766c..e281d295 100644 --- a/tests/test_tracking_routines.py +++ b/tests/test_tracking_routines.py @@ -2,32 +2,6 @@ import numpy as np from laptrack._tracking import _get_segment_df -from laptrack._tracking import _remove_no_split_merge_links - - -def test_remove_no_split_merge_links() -> None: - test_tree = nx.Graph() - test_tree.add_edges_from( - [ - ((0, 0), (1, 0)), - ((1, 0), (2, 0)), - ((2, 0), (3, 0)), - ((3, 0), (4, 0)), - ((1, 0), (2, 1)), - ((2, 1), (3, 1)), - ] - ) - segment_connected_edges = [ - ((1, 0), (2, 0)), - ((2, 0), (3, 0)), - ((2, 1), (3, 1)), - ] - removed_edges = [ - ((2, 0), (3, 0)), - ((2, 1), (3, 1)), - ] - res_tree = _remove_no_split_merge_links(test_tree.copy(), segment_connected_edges) - assert set(test_tree.edges) - set(res_tree.edges) == set(removed_edges) def test_get_segment_df() -> None: From 7b24332f9eaa91dbcab257904d8faf599babe038 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 13 Jun 2023 19:32:16 +0900 Subject: [PATCH 3/5] refacotr --- src/laptrack/_tracking.py | 514 ++++++++++++++++++-------------------- 1 file changed, 250 insertions(+), 264 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 7ea33018..553f1b3c 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -93,182 +93,6 @@ def _get_segment_df(coords, track_tree): return segments_df -def _get_segment_end_connecting_matrix( - segments_df, - max_frame_count, - dist_metric, - cost_cutoff, - *, - force_end_nodes=[], - force_start_nodes=[], -): - """ - Generate the cost matrix for connecting segment ends. - - Parameters - ---------- - segments_df : pd.DataFrame - must have the columns "first_frame", "first_index", "first_crame_coords", "last_frame", "last_index", "last_frame_coords" - max_frame_count : int - connecting cost is set to infinity if the distance between the two ends is larger than this value - dist_metric : - the distance metric - cost_cutoff : float - the cutoff value for the cost - force_end_nodes : list of int - the indices of the segments_df that is forced to be end for future connection - force_start_nodes : list of int - the indices of the segments_df that is forced to be start for future connection - - Returns - ------- - segments_df: pd.DataFrame - the segments dataframe with additional column "gap_closing_candidates" - (index of the candidate row of segments_df, the associated costs) - - """ - if cost_cutoff: - - def to_gap_closing_candidates(row): - # if the index is in force_end_indices, do not add to gap closing candidates - if (row["last_frame"], row["last_index"]) in force_end_nodes: - return [], [] - - target_coord = row["last_frame_coords"] - frame_diff = segments_df["first_frame"] - row["last_frame"] - - # only take the elements that are within the frame difference range. - # segments in df is later than the candidate segment (row) - indices = (1 <= frame_diff) & (frame_diff <= max_frame_count) - df = segments_df[indices] - force_start = df.apply( - lambda row: (row["first_frame"], row["first_index"]) - in force_start_nodes, - axis=1, - ) - df = df[~force_start] - # do not connect to the segments that is forced to be start - # note: can use KDTree if metric is distance, - # but might not be appropriate for general metrics - # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa - # TrackMate also uses this (trivial) implementation. - if len(df) > 0: - target_dist_matrix = cdist( - [target_coord], - np.stack(df["first_frame_coords"].values), - metric=dist_metric, - ) - assert target_dist_matrix.shape[0] == 1 - indices2 = np.where(target_dist_matrix[0] < cost_cutoff)[0] - return ( - df.index[indices2].values, - target_dist_matrix[0][indices2], - ) - else: - return [], [] - - segments_df["gap_closing_candidates"] = segments_df.apply( - to_gap_closing_candidates, axis=1 - ) - else: - segments_df["gap_closing_candidates"] = [([], [])] * len(segments_df) - - N_segments = len(segments_df) - gap_closing_dist_matrix = coo_matrix_builder( - (N_segments, N_segments), dtype=np.float32 - ) - for ind, row in segments_df.iterrows(): - candidate_inds = row["gap_closing_candidates"][0] - candidate_costs = row["gap_closing_candidates"][1] - # row ... track end, col ... track start - gap_closing_dist_matrix[(int(cast(int, ind)), candidate_inds)] = candidate_costs - - return segments_df, gap_closing_dist_matrix - - -def _get_splitting_merging_candidates( - segments_df, - coords, - cutoff, - prefix, - dist_metric, - *, - force_end_nodes=[], - force_start_nodes=[], -): - if cutoff: - - def to_candidates(row): - # if the prefix is first, this means the row is the track start, and the target is the track end - other_frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) - target_coord = row[f"{prefix}_frame_coords"] - row_no_connection_nodes = ( - force_start_nodes if prefix == "first" else force_end_nodes - ) - other_no_connection_nodes = ( - force_end_nodes if prefix == "first" else force_start_nodes - ) - other_no_connection_indices = [ - n[1] for n in other_no_connection_nodes if n[0] == other_frame - ] - - if ( - row[f"{prefix}_frame"], - row[f"{prefix}_index"], - ) in row_no_connection_nodes: - return ( - [], - [], - ) # do not connect to the segments that is forced to be start or end - # note: can use KDTree if metric is distance, - # but might not be appropriate for general metrics - # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa - if other_frame < 0 or len(coords) <= other_frame: - return [], [] - target_dist_matrix = cdist( - [target_coord], coords[other_frame], metric=dist_metric - ) - assert target_dist_matrix.shape[0] == 1 - target_dist_matrix[ - 0, other_no_connection_indices - ] = ( - np.inf - ) # do not connect to the segments that is forced to be start or end - indices = np.where(target_dist_matrix[0] < cutoff)[0] - return [(other_frame, index) for index in indices], target_dist_matrix[0][ - indices - ] - - segments_df[f"{prefix}_candidates"] = segments_df.apply(to_candidates, axis=1) - else: - segments_df[f"{prefix}_candidates"] = [([], [])] * len(segments_df) - - middle_point_candidates = np.unique( - sum( - segments_df[f"{prefix}_candidates"].apply(lambda x: list(x[0])), - [], - ), - axis=0, - ) - - N_segments = len(segments_df) - N_middle = len(middle_point_candidates) - dist_matrix = coo_matrix_builder((N_segments, N_middle), dtype=np.float32) - - middle_point_candidates_dict = { - tuple(val): i for i, val in enumerate(middle_point_candidates) - } - for ind, row in segments_df.iterrows(): - candidate_frame_indices = row[f"{prefix}_candidates"][0] - candidate_inds = [ - middle_point_candidates_dict[tuple(fi)] for fi in candidate_frame_indices - ] - candidate_costs = row[f"{prefix}_candidates"][1] - dist_matrix[(int(cast(Int, ind)), candidate_inds)] = candidate_costs - - return segments_df, dist_matrix, middle_point_candidates - - class LapTrack(BaseModel, extra=Extra.forbid): """Tracking class for LAP tracker with parameters.""" @@ -371,7 +195,7 @@ class LapTrack(BaseModel, extra=Extra.forbid): exclude=True, ) - def _link_frames( + def _predict_links( self, coords, segment_connected_edges, split_merge_edges ) -> nx.Graph: """ @@ -401,7 +225,7 @@ def _link_frames( edges_list = list(segment_connected_edges) + list(split_merge_edges) - def _link_single_frame( + def _predict_link_single_frame( frame: int, coord1: np.ndarray, coord2: np.ndarray, @@ -440,14 +264,14 @@ def _link_single_frame( if self.parallel_backend == ParallelBackend.serial: all_edges = [] for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): - edges = _link_single_frame(frame, coord1, coord2) + edges = _predict_link_single_frame(frame, coord1, coord2) all_edges.extend(edges) elif self.parallel_backend == ParallelBackend.ray: try: import ray except ImportError: raise ImportError("Please install `ray` to use `ParallelBackend.ray`.") - remote_func = ray.remote(_link_single_frame) + remote_func = ray.remote(_predict_link_single_frame) res = [ remote_func.remote(frame, coord1, coord2) for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])) @@ -461,15 +285,177 @@ def _link_single_frame( def _get_gap_closing_matrix( self, segments_df, *, force_end_nodes=[], force_start_nodes=[] ): - return _get_segment_end_connecting_matrix( - segments_df, - self.gap_closing_max_frame_count, - self.gap_closing_dist_metric, - self.gap_closing_cost_cutoff, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, + """ + Generate the cost matrix for connecting segment ends. + + Parameters + ---------- + segments_df : pd.DataFrame + must have the columns "first_frame", "first_index", "first_crame_coords", "last_frame", "last_index", "last_frame_coords" + force_end_nodes : list of int + the indices of the segments_df that is forced to be end for future connection + force_start_nodes : list of int + the indices of the segments_df that is forced to be start for future connection + + Returns + ------- + segments_df: pd.DataFrame + the segments dataframe with additional column "gap_closing_candidates" + (index of the candidate row of segments_df, the associated costs) + gap_closing_dist_matrix: coo_matrix_builder + the cost matrix for gap closing candidates + + """ + if self.gap_closing_cost_cutoff: + + def to_gap_closing_candidates(row): + # if the index is in force_end_indices, do not add to gap closing candidates + if (row["last_frame"], row["last_index"]) in force_end_nodes: + return [], [] + + target_coord = row["last_frame_coords"] + frame_diff = segments_df["first_frame"] - row["last_frame"] + + # only take the elements that are within the frame difference range. + # segments in df is later than the candidate segment (row) + indices = (1 <= frame_diff) & ( + frame_diff <= self.gap_closing_max_frame_count + ) + df = segments_df[indices] + force_start = df.apply( + lambda row: (row["first_frame"], row["first_index"]) + in force_start_nodes, + axis=1, + ) + df = df[~force_start] + # do not connect to the segments that is forced to be start + # note: can use KDTree if metric is distance, + # but might not be appropriate for general metrics + # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa + # TrackMate also uses this (trivial) implementation. + if len(df) > 0: + target_dist_matrix = cdist( + [target_coord], + np.stack(df["first_frame_coords"].values), + metric=self.gap_closing_dist_metric, + ) + assert target_dist_matrix.shape[0] == 1 + indices2 = np.where( + target_dist_matrix[0] < self.gap_closing_cost_cutoff + )[0] + return ( + df.index[indices2].values, + target_dist_matrix[0][indices2], + ) + else: + return [], [] + + segments_df["gap_closing_candidates"] = segments_df.apply( + to_gap_closing_candidates, axis=1 + ) + else: + segments_df["gap_closing_candidates"] = [([], [])] * len(segments_df) + + N_segments = len(segments_df) + gap_closing_dist_matrix = coo_matrix_builder( + (N_segments, N_segments), dtype=np.float32 + ) + for ind, row in segments_df.iterrows(): + candidate_inds = row["gap_closing_candidates"][0] + candidate_costs = row["gap_closing_candidates"][1] + # row ... track end, col ... track start + gap_closing_dist_matrix[ + (int(cast(int, ind)), candidate_inds) + ] = candidate_costs + + return segments_df, gap_closing_dist_matrix + + def _get_splitting_merging_candidates( + self, + segments_df, + coords, + cutoff, + prefix, + dist_metric, + *, + force_end_nodes=[], + force_start_nodes=[], + ): + if cutoff: + + def to_candidates(row): + # if the prefix is first, this means the row is the track start, and the target is the track end + other_frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) + target_coord = row[f"{prefix}_frame_coords"] + row_no_connection_nodes = ( + force_start_nodes if prefix == "first" else force_end_nodes + ) + other_no_connection_nodes = ( + force_end_nodes if prefix == "first" else force_start_nodes + ) + other_no_connection_indices = [ + n[1] for n in other_no_connection_nodes if n[0] == other_frame + ] + + if ( + row[f"{prefix}_frame"], + row[f"{prefix}_index"], + ) in row_no_connection_nodes: + return ( + [], + [], + ) # do not connect to the segments that is forced to be start or end + # note: can use KDTree if metric is distance, + # but might not be appropriate for general metrics + # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa + if other_frame < 0 or len(coords) <= other_frame: + return [], [] + target_dist_matrix = cdist( + [target_coord], coords[other_frame], metric=dist_metric + ) + assert target_dist_matrix.shape[0] == 1 + target_dist_matrix[ + 0, other_no_connection_indices + ] = ( + np.inf + ) # do not connect to the segments that is forced to be start or end + indices = np.where(target_dist_matrix[0] < cutoff)[0] + return [(other_frame, index) for index in indices], target_dist_matrix[ + 0 + ][indices] + + segments_df[f"{prefix}_candidates"] = segments_df.apply( + to_candidates, axis=1 + ) + else: + segments_df[f"{prefix}_candidates"] = [([], [])] * len(segments_df) + + middle_point_candidates = np.unique( + sum( + segments_df[f"{prefix}_candidates"].apply(lambda x: list(x[0])), + [], + ), + axis=0, ) + N_segments = len(segments_df) + N_middle = len(middle_point_candidates) + dist_matrix = coo_matrix_builder((N_segments, N_middle), dtype=np.float32) + + middle_point_candidates_dict = { + tuple(val): i for i, val in enumerate(middle_point_candidates) + } + for ind, row in segments_df.iterrows(): + candidate_frame_indices = row[f"{prefix}_candidates"][0] + candidate_inds = [ + middle_point_candidates_dict[tuple(fi)] + for fi in candidate_frame_indices + ] + candidate_costs = row[f"{prefix}_candidates"][1] + dist_matrix[(int(cast(Int, ind)), candidate_inds)] = candidate_costs + + return segments_df, dist_matrix, middle_point_candidates + def _link_gap_split_merge_from_matrix( self, segments_df, @@ -526,6 +512,82 @@ def _link_gap_split_merge_from_matrix( return track_tree + def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): + """ + Perform gap-closing and splitting/merging prediction. + + Parameters + ---------- + coords : Sequence[NumArray] + The list of coordinates of point for each frame. + The array index means (sample, dimension). + track_tree : nx.Graph + the track tree + connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): + the connected edges list + + Returns + ------- + track_tree : nx.Graph + the updated track tree + """ + edges = list(split_edges) + list(merge_edges) + if ( + self.gap_closing_cost_cutoff + or self.splitting_cost_cutoff + or self.merging_cost_cutoff + ): + segments_df = _get_segment_df(coords, track_tree) + force_end_nodes = [tuple(map(int, e[0])) for e in edges] + force_start_nodes = [tuple(map(int, e[1])) for e in edges] + + # compute candidate for gap closing + segments_df, gap_closing_dist_matrix = self._get_gap_closing_matrix( + segments_df, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, + ) + + middle_points: Dict = {} + dist_matrices: Dict = {} + + # compute candidate for splitting and merging + for prefix, cutoff, dist_metric in zip( + ["first", "last"], + [self.splitting_cost_cutoff, self.merging_cost_cutoff], + [self.splitting_dist_metric, self.merging_dist_metric], + ): + ( + segments_df, + dist_matrices[prefix], + middle_points[prefix], + ) = self._get_splitting_merging_candidates( + segments_df, + coords, + cutoff, + prefix, + dist_metric, + force_end_nodes=force_end_nodes, + force_start_nodes=force_start_nodes, + ) + + splitting_dist_matrix = dist_matrices["first"] + merging_dist_matrix = dist_matrices["last"] + splitting_all_candidates = middle_points["first"] + merging_all_candidates = middle_points["last"] + + track_tree = self._link_gap_split_merge_from_matrix( + segments_df, + track_tree, + gap_closing_dist_matrix, + splitting_dist_matrix, + merging_dist_matrix, + splitting_all_candidates, + merging_all_candidates, + ) + track_tree.add_edges_from(edges) + return track_tree + def predict( self, coords: Sequence[NumArray], @@ -589,7 +651,7 @@ def predict( merge_edges = [] ####### Particle-particle tracking ####### - track_tree = self._link_frames( + track_tree = self._predict_links( coords, segment_connected_edges, list(split_edges) + list(merge_edges) ) track_tree = self._predict_gap_split_merge( @@ -693,82 +755,6 @@ def predict_dataframe( return track_df, split_df, merge_df - def _predict_gap_split_merge(self, coords, track_tree, split_edges, merge_edges): - """ - Perform gap-closing and splitting/merging prediction. - - Parameters - ---------- - coords : Sequence[NumArray] - The list of coordinates of point for each frame. - The array index means (sample, dimension). - track_tree : nx.Graph - the track tree - connected_edges_list (List[List[Tuple[Tuple[int, int],Tuple[int, int]]]]): - the connected edges list - - Returns - ------- - track_tree : nx.Graph - the updated track tree - """ - edges = list(split_edges) + list(merge_edges) - if ( - self.gap_closing_cost_cutoff - or self.splitting_cost_cutoff - or self.merging_cost_cutoff - ): - segments_df = _get_segment_df(coords, track_tree) - force_end_nodes = [tuple(map(int, e[0])) for e in edges] - force_start_nodes = [tuple(map(int, e[1])) for e in edges] - - # compute candidate for gap closing - segments_df, gap_closing_dist_matrix = self._get_gap_closing_matrix( - segments_df, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, - ) - - middle_points: Dict = {} - dist_matrices: Dict = {} - - # compute candidate for splitting and merging - for prefix, cutoff, dist_metric in zip( - ["first", "last"], - [self.splitting_cost_cutoff, self.merging_cost_cutoff], - [self.splitting_dist_metric, self.merging_dist_metric], - ): - ( - segments_df, - dist_matrices[prefix], - middle_points[prefix], - ) = _get_splitting_merging_candidates( - segments_df, - coords, - cutoff, - prefix, - dist_metric, - force_end_nodes=force_end_nodes, - force_start_nodes=force_start_nodes, - ) - - splitting_dist_matrix = dist_matrices["first"] - merging_dist_matrix = dist_matrices["last"] - splitting_all_candidates = middle_points["first"] - merging_all_candidates = middle_points["last"] - - track_tree = self._link_gap_split_merge_from_matrix( - segments_df, - track_tree, - gap_closing_dist_matrix, - splitting_dist_matrix, - merging_dist_matrix, - splitting_all_candidates, - merging_all_candidates, - ) - track_tree.add_edges_from(edges) - return track_tree - def laptrack(coords: Sequence[NumArray], **kwargs) -> nx.Graph: """ From fc1cb3ed0f01333e9470f9e9b49bd146ce876dce Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 13 Jun 2023 19:37:23 +0900 Subject: [PATCH 4/5] added parallel computation for gap closing and splitting --- src/laptrack/_tracking.py | 43 +++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 553f1b3c..93cf0742 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -277,6 +277,11 @@ def _predict_link_single_frame( for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])) ] all_edges = sum(ray.get(res), []) + else: + raise ValueError( + f"Unknown parallel backend {self.parallel_backend}. " + + f"Must be one of {', '.join([ps.name for ps in ParallelBackend])}." + ) track_tree.add_edges_from(all_edges) track_tree.add_edges_from(segment_connected_edges) @@ -350,9 +355,22 @@ def to_gap_closing_candidates(row): else: return [], [] - segments_df["gap_closing_candidates"] = segments_df.apply( - to_gap_closing_candidates, axis=1 - ) + if self.parallel_backend == ParallelBackend.serial: + segments_df["gap_closing_candidates"] = segments_df.apply( + to_gap_closing_candidates, axis=1 + ) + elif self.parallel_backend == ParallelBackend.ray: + try: + import ray + except ImportError: + raise ImportError( + "Please install `ray` to use `ParallelBackend.ray`." + ) + remote_func = ray.remote(to_gap_closing_candidates) + res = [remote_func.remote(row) for _, row in segments_df.iterrows()] + segments_df["gap_closing_candidates"] = ray.get(res) + else: + raise ValueError(f"Unknown parallel_backend {self.parallel_backend}. ") else: segments_df["gap_closing_candidates"] = [([], [])] * len(segments_df) @@ -424,9 +442,22 @@ def to_candidates(row): 0 ][indices] - segments_df[f"{prefix}_candidates"] = segments_df.apply( - to_candidates, axis=1 - ) + if self.parallel_backend == ParallelBackend.serial: + segments_df[f"{prefix}_candidates"] = segments_df.apply( + to_candidates, axis=1 + ) + elif self.parallel_backend == ParallelBackend.ray: + try: + import ray + except ImportError: + raise ImportError( + "Please install `ray` to use `ParallelBackend.ray`." + ) + remote_func = ray.remote(to_candidates) + res = [remote_func.remote(row) for _, row in segments_df.iterrows()] + segments_df[f"{prefix}_candidates"] = ray.get(res) + else: + raise ValueError(f"Unknown parallel_backend {self.parallel_backend}. ") else: segments_df[f"{prefix}_candidates"] = [([], [])] * len(segments_df) From 0186ec944ca1789e5c550d3768a5980196c700ee Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 13 Jun 2023 19:40:08 +0900 Subject: [PATCH 5/5] added test for wrong backend --- tests/test_tracking.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index f655bf2e..691edeb5 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -237,6 +237,11 @@ def test_no_accepting_wrong_argments() -> None: lt = LapTrack(fugafuga=True) +def test_no_accepting_wrong_backend() -> None: + with pytest.raises(ValidationError): + lt = LapTrack(parallel_backend="hogehoge") + + def df_to_tuples(df): return tuple([tuple(map(int, v)) for v in df.values])