diff --git a/docs/usage.rst b/docs/usage.rst index fd1e48a3..d4a37c43 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -5,4 +5,4 @@ Example usage: tracking pre-detected spots .. literalinclude:: ../examples/lap_tracking_example.py :language: python - :emphasize-lines: 29-35 + :emphasize-lines: 29-37 diff --git a/examples/brownian_motion_tracking.py b/examples/brownian_motion_tracking.py index 145a3311..d1482f6f 100644 --- a/examples/brownian_motion_tracking.py +++ b/examples/brownian_motion_tracking.py @@ -4,7 +4,7 @@ import numpy as np from matplotlib import pyplot as plt -from laptrack import laptrack +from laptrack import LapTrack track_length = 100 track_count = 10 @@ -41,7 +41,8 @@ plt.legend() spots = [np.array([pos[t] for pos in brownian_poss]) for t in range(track_length)] -tree = laptrack(spots) +lt = LapTrack() +tree = lt.predict(spots) #%% # noqa: for edge in tree.edges(): if (edge[0][0] + 1 != edge[1][0]) or (edge[0][1] != edge[1][1]): diff --git a/examples/lap_tracking_example.py b/examples/lap_tracking_example.py index b2d8bbd7..7d9e90fd 100644 --- a/examples/lap_tracking_example.py +++ b/examples/lap_tracking_example.py @@ -2,7 +2,7 @@ import pandas as pd -from laptrack import laptrack +from laptrack import LapTrack script_path = path.dirname(path.realpath(__file__)) filename = "../tests/data/trackmate_tracks_with_splitting_spots.csv" @@ -26,13 +26,15 @@ # ] max_distance = 15 -track_tree = laptrack( - coords, +lt = LapTrack( track_dist_metric="sqeuclidean", splitting_dist_metric="sqeuclidean", track_cost_cutoff=max_distance**2, splitting_cost_cutoff=max_distance**2, ) +track_tree = lt.predict( + coords, +) for edge in track_tree.edges(): print(edge) diff --git a/examples/napari_example.py b/examples/napari_example.py index 5cd66a6b..f9af3dfa 100644 --- a/examples/napari_example.py +++ b/examples/napari_example.py @@ -6,7 +6,7 @@ from matplotlib import pyplot as plt from skimage.feature import blob_log -from laptrack import laptrack +from laptrack 
import LapTrack #%% # noqa: viewer = napari.Viewer() @@ -65,7 +65,8 @@ # %% spots_for_tracking = [spots[spots[:, 0] == j][:, 1:] for j in range(track_length)] -track_tree = laptrack(spots_for_tracking) +lt = LapTrack() +track_tree = lt.predict(spots_for_tracking) tracks = [] for i, segment in enumerate(nx.connected_components(track_tree)): @@ -105,7 +106,7 @@ for j in np.sort(np.unique(extracted_spots[:, 0])) ] -test_track = laptrack(extracted_spots_for_tracking, gap_closing_cost_cutoff=False) +test_track = LapTrack(gap_closing_cost_cutoff=False).predict(extracted_spots_for_tracking) for edge in test_track.edges(): _e = np.array(edge) pos = [extracted_spots_for_tracking[frame][ind] for frame, ind in edge] diff --git a/poetry.lock b/poetry.lock index 5cfe38f7..95080383 100644 --- a/poetry.lock +++ b/poetry.lock @@ -977,6 +977,21 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "pydantic" +version = "1.9.1" +description = "Data validation and settings management using python type hints" +category = "main" +optional = false +python-versions = ">=3.6.1" + +[package.dependencies] +typing-extensions = ">=3.7.4.3" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pygments" version = "2.12.0" @@ -1567,7 +1582,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8.1,<3.11" -content-hash = "59b62da15bd2197a3cbc3495e64360f8753505e5c367418696f76ee876d8571e" +content-hash = "619b612f720216e603cf3f78bf0bdf2c313a2a48aa9da6ef0e849ef13e9fc520" [metadata.files] alabaster = [ @@ -2234,6 +2249,43 @@ pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +pydantic
= [ + {file = "pydantic-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8098a724c2784bf03e8070993f6d46aa2eeca031f8d8a048dff277703e6e193"}, + {file = "pydantic-1.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c320c64dd876e45254bdd350f0179da737463eea41c43bacbee9d8c9d1021f11"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18f3e912f9ad1bdec27fb06b8198a2ccc32f201e24174cec1b3424dda605a310"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11951b404e08b01b151222a1cb1a9f0a860a8153ce8334149ab9199cd198131"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8bc541a405423ce0e51c19f637050acdbdf8feca34150e0d17f675e72d119580"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e565a785233c2d03724c4dc55464559639b1ba9ecf091288dd47ad9c629433bd"}, + {file = "pydantic-1.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a4a88dcd6ff8fd47c18b3a3709a89adb39a6373f4482e04c1b765045c7e282fd"}, + {file = "pydantic-1.9.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:447d5521575f18e18240906beadc58551e97ec98142266e521c34968c76c8761"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:985ceb5d0a86fcaa61e45781e567a59baa0da292d5ed2e490d612d0de5796918"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059b6c1795170809103a1538255883e1983e5b831faea6558ef873d4955b4a74"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d12f96b5b64bec3f43c8e82b4aab7599d0157f11c798c9f9c528a72b9e0b339a"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ae72f8098acb368d877b210ebe02ba12585e77bd0db78ac04a1ee9b9f5dd2166"}, + {file = "pydantic-1.9.1-cp36-cp36m-win_amd64.whl", hash = 
"sha256:79b485767c13788ee314669008d01f9ef3bc05db9ea3298f6a50d3ef596a154b"}, + {file = "pydantic-1.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:494f7c8537f0c02b740c229af4cb47c0d39840b829ecdcfc93d91dcbb0779892"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0f047e11febe5c3198ed346b507e1d010330d56ad615a7e0a89fae604065a0e"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:969dd06110cb780da01336b281f53e2e7eb3a482831df441fb65dd30403f4608"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:177071dfc0df6248fd22b43036f936cfe2508077a72af0933d0c1fa269b18537"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9bcf8b6e011be08fb729d110f3e22e654a50f8a826b0575c7196616780683380"}, + {file = "pydantic-1.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a955260d47f03df08acf45689bd163ed9df82c0e0124beb4251b1290fa7ae728"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ce157d979f742a915b75f792dbd6aa63b8eccaf46a1005ba03aa8a986bde34a"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0bf07cab5b279859c253d26a9194a8906e6f4a210063b84b433cf90a569de0c1"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d93d4e95eacd313d2c765ebe40d49ca9dd2ed90e5b37d0d421c597af830c195"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1542636a39c4892c4f4fa6270696902acb186a9aaeac6f6cf92ce6ae2e88564b"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a9af62e9b5b9bc67b2a195ebc2c2662fdf498a822d62f902bf27cccb52dbbf49"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fe4670cb32ea98ffbf5a1262f14c3e102cccd92b1869df3bb09538158ba90fe6"}, + {file = 
"pydantic-1.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:9f659a5ee95c8baa2436d392267988fd0f43eb774e5eb8739252e5a7e9cf07e0"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b83ba3825bc91dfa989d4eed76865e71aea3a6ca1388b59fc801ee04c4d8d0d6"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1dd8fecbad028cd89d04a46688d2fcc14423e8a196d5b0a5c65105664901f810"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02eefd7087268b711a3ff4db528e9916ac9aa18616da7bca69c1871d0b7a091f"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7eb57ba90929bac0b6cc2af2373893d80ac559adda6933e562dcfb375029acee"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4ce9ae9e91f46c344bec3b03d6ee9612802682c1551aaf627ad24045ce090761"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:72ccb318bf0c9ab97fc04c10c37683d9eea952ed526707fabf9ac5ae59b701fd"}, + {file = "pydantic-1.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:61b6760b08b7c395975d893e0b814a11cf011ebb24f7d869e7118f5a339a82e1"}, + {file = "pydantic-1.9.1-py3-none-any.whl", hash = "sha256:4988c0f13c42bfa9ddd2fe2f569c9d54646ce84adc5de84228cfe83396f3bd58"}, + {file = "pydantic-1.9.1.tar.gz", hash = "sha256:1ed987c3ff29fff7fd8c3ea3a3ea877ad310aae2ef9889a119e22d3f2db0691a"}, +] pygments = [ {file = "Pygments-2.12.0-py3-none-any.whl", hash = "sha256:dc9c10fb40944260f6ed4c688ece0cd2048414940f1cea51b8b226318411c519"}, {file = "Pygments-2.12.0.tar.gz", hash = "sha256:5eb116118f9612ff1ee89ac96437bb6b49e8f04d8a13b514ba26f620208e26eb"}, diff --git a/pyproject.toml b/pyproject.toml index d246718f..10e1d5c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "laptrack" -version = "0.1.6" +version = "0.1.7-alpha.2" description = "LapTrack" authors = ["Yohsuke Fukai "] license = 
"BSD-3-Clause" @@ -32,6 +32,7 @@ scipy = "^1.7.0" networkx = "^2.6.1" pandas = "^1.3.1" typing-extensions = "^3.10.0" +pydantic = "^1.9.1" [tool.poetry.dev-dependencies] pytest = "^6.2.4" diff --git a/src/laptrack/__init__.py b/src/laptrack/__init__.py index 9439e995..cee91112 100644 --- a/src/laptrack/__init__.py +++ b/src/laptrack/__init__.py @@ -3,6 +3,6 @@ __author__ = """Yohsuke T. Fukai""" __email__ = "ysk@yfukai.net" -from ._tracking import laptrack +from ._tracking import laptrack, LapTrack, LapTrackMulti -__all__ = ["laptrack"] +__all__ = ["laptrack", "LapTrack", "LapTrackMulti"] diff --git a/src/laptrack/_cost_matrix.py b/src/laptrack/_cost_matrix.py index 8ad9121d..448226da 100644 --- a/src/laptrack/_cost_matrix.py +++ b/src/laptrack/_cost_matrix.py @@ -2,7 +2,7 @@ from typing import Union import numpy as np -from scipy.sparse.coo import coo_matrix +from scipy.sparse import coo_matrix from ._typing_utils import Float from ._typing_utils import Matrix @@ -67,7 +67,7 @@ def build_segment_cost_matrix( alternative_cost_factor: Float = 1.05, alternative_cost_percentile: Float = 90, alternative_cost_percentile_interpolation: str = "lower", -) -> coo_matrix: +) -> Optional[coo_matrix]: """Build sparce array for segment-linking cost matrix. Parameters @@ -141,7 +141,7 @@ def build_segment_cost_matrix( # XXX seems numpy / mypy is over-strict here. Will fix later. 
C.data, # type: ignore alternative_cost_percentile, - interpolation=alternative_cost_percentile_interpolation, + method=alternative_cost_percentile_interpolation, ) * alternative_cost_factor ) diff --git a/src/laptrack/_optimization.py b/src/laptrack/_optimization.py index aa2b6bf8..c958a6a1 100644 --- a/src/laptrack/_optimization.py +++ b/src/laptrack/_optimization.py @@ -1,7 +1,7 @@ from typing import Tuple import lap -from scipy.sparse.csr import csr_matrix +from scipy.sparse import csr_matrix from ._typing_utils import FloatArray from ._typing_utils import Int diff --git a/src/laptrack/_tracking.py b/src/laptrack/_tracking.py index 9d495106..8d561c5e 100644 --- a/src/laptrack/_tracking.py +++ b/src/laptrack/_tracking.py @@ -1,4 +1,9 @@ """Main module for tracking.""" +import logging +from abc import ABC +from abc import abstractmethod +from inspect import Parameter +from inspect import signature from typing import Callable from typing import cast from typing import Dict @@ -20,323 +25,655 @@ import numpy as np import pandas as pd from scipy.spatial.distance import cdist +from scipy.sparse import coo_matrix +from pydantic import BaseModel, Field + from ._cost_matrix import build_frame_cost_matrix, build_segment_cost_matrix from ._optimization import lap_optimization -from ._typing_utils import Float from ._typing_utils import FloatArray from ._typing_utils import Int from ._utils import coo_matrix_builder +logger = logging.getLogger(__name__) + + +def _get_segment_df(coords, track_tree): + # linking between tracks + segments = list(nx.connected_components(track_tree)) + first_nodes = np.array( + list(map(lambda segment: min(segment, key=lambda node: node[0]), segments)) + ) + last_nodes = np.array( + list(map(lambda segment: max(segment, key=lambda node: node[0]), segments)) + ) + segments_df = pd.DataFrame( + { + "segment": segments, + "first_frame": first_nodes[:, 0], + "first_index": first_nodes[:, 1], + "last_frame": last_nodes[:, 0], + "last_index": 
last_nodes[:, 1], + } + ).reset_index() + + for prefix in ["first", "last"]: + segments_df[f"{prefix}_frame_coords"] = segments_df.apply( + lambda row: coords[row[f"{prefix}_frame"]][row[f"{prefix}_index"]], + axis=1, + ) + return segments_df + + +def _get_segment_end_connecting_matrix( + segments_df, max_frame_count, dist_metric, cost_cutoff +): + if cost_cutoff: + + def to_gap_closing_candidates(row): + target_coord = row["last_frame_coords"] + frame_diff = segments_df["first_frame"] - row["last_frame"] + indices = (1 <= frame_diff) & (frame_diff <= max_frame_count) + df = segments_df[indices] + # note: can use KDTree if metric is distance, + # but might not be appropriate for general metrics + # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa + # TrackMate also uses this (trivial) implementation. + if len(df) > 0: + target_dist_matrix = cdist( + [target_coord], + np.stack(df["first_frame_coords"].values), + metric=dist_metric, + ) + assert target_dist_matrix.shape[0] == 1 + indices2 = np.where(target_dist_matrix[0] < cost_cutoff)[0] + return ( + df.index[indices2].values, + target_dist_matrix[0][indices2], + ) + else: + return [], [] -def laptrack( - coords: Sequence[FloatArray], - track_dist_metric: Union[str, Callable] = "sqeuclidean", - splitting_dist_metric: Union[str, Callable] = "sqeuclidean", - merging_dist_metric: Union[str, Callable] = "sqeuclidean", - alternative_cost_factor: Float = 1.05, - alternative_cost_percentile: Float = 90, - alternative_cost_percentile_interpolation: str = "lower", - track_cost_cutoff: Float = 15**2, - track_start_cost: Optional[Float] = None, # b in Jaqaman et al 2008 NMeth. - track_end_cost: Optional[Float] = None, # d in Jaqaman et al 2008 NMeth. 
- gap_closing_cost_cutoff: Union[Float, Literal[False]] = 15**2, - gap_closing_max_frame_count: Int = 2, - splitting_cost_cutoff: Union[Float, Literal[False]] = False, - no_splitting_cost: Optional[Float] = None, # d' in Jaqaman et al 2008 NMeth. - merging_cost_cutoff: Union[Float, Literal[False]] = False, - no_merging_cost: Optional[Float] = None, # b' in Jaqaman et al 2008 NMeth. -) -> nx.Graph: - """Track points by solving linear assignment problem. + segments_df["gap_closing_candidates"] = segments_df.apply( + to_gap_closing_candidates, axis=1 + ) + else: + segments_df["gap_closing_candidates"] = [([], [])] * len(segments_df) + + N_segments = len(segments_df) + gap_closing_dist_matrix = coo_matrix_builder( + (N_segments, N_segments), dtype=np.float32 + ) + for ind, row in segments_df.iterrows(): + candidate_inds = row["gap_closing_candidates"][0] + candidate_costs = row["gap_closing_candidates"][1] + # row ... track end, col ... track start + gap_closing_dist_matrix[(int(cast(int, ind)), candidate_inds)] = candidate_costs + + return segments_df, gap_closing_dist_matrix + + +def _get_splitting_merging_candidates( + segments_df, + coords, + cutoff, + prefix, + dist_metric, +): + if cutoff: + + def to_candidates(row): + target_coord = row[f"{prefix}_frame_coords"] + frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) + # note: can use KDTree if metric is distance, + # but might not be appropriate for general metrics + # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa + if frame < 0 or len(coords) <= frame: + return [], [] + target_dist_matrix = cdist( + [target_coord], coords[frame], metric=dist_metric + ) + assert target_dist_matrix.shape[0] == 1 + indices = np.where(target_dist_matrix[0] < cutoff)[0] + return [(frame, index) for index in indices], target_dist_matrix[0][indices] + + segments_df[f"{prefix}_candidates"] = segments_df.apply(to_candidates, axis=1) + else: + 
segments_df[f"{prefix}_candidates"] = [([], [])] * len(segments_df) + + middle_point_candidates = np.unique( + sum( + segments_df[f"{prefix}_candidates"].apply(lambda x: list(x[0])), + [], + ), + axis=0, + ) + + N_segments = len(segments_df) + N_middle = len(middle_point_candidates) + dist_matrix = coo_matrix_builder((N_segments, N_middle), dtype=np.float32) + + middle_point_candidates_dict = { + tuple(val): i for i, val in enumerate(middle_point_candidates) + } + for ind, row in segments_df.iterrows(): + candidate_frame_indices = row[f"{prefix}_candidates"][0] + candidate_inds = [ + middle_point_candidates_dict[tuple(fi)] for fi in candidate_frame_indices + ] + candidate_costs = row[f"{prefix}_candidates"][1] + dist_matrix[(int(cast(Int, ind)), candidate_inds)] = candidate_costs + + return segments_df, dist_matrix, middle_point_candidates + + +def _remove_no_split_merge_links(track_tree, segment_connected_edges): + for edge in segment_connected_edges: + assert len(edge) == 2 + younger, elder = edge + # if the edge is involved with branching or merging, do not remove the edge + if ( + sum([int(node[0] > younger[0]) for node in track_tree.neighbors(younger)]) + > 1 + ): + continue + if sum([int(node[0] < elder[0]) for node in track_tree.neighbors(elder)]) > 1: + continue + track_tree.remove_edge(*edge) + return track_tree - Parameters - ---------- - coords : Sequence[FloatArray] - The list of coordinates of point for each frame. - The array index means (sample, dimension). - track_dist_metric : str or Callable, optional - The metric for calculating track linking cost, - by default 'sqeuclidean' (squared euclidean distance). - See documentation for `scipy.spatial.distance.cdist` for accepted values. 
+class LapTrackBase(BaseModel, ABC): + track_dist_metric: Union[str, Callable] = Field( + "sqeuclidean", + description="The metric for calculating track linking cost, " + + "See documentation for `scipy.spatial.distance.cdist` for accepted values.", + ) + splitting_dist_metric: Union[str, Callable] = Field( + "sqeuclidean", + description="The metric for calculating splitting cost." + + "See `track_dist_metric`", + ) + merging_dist_metric: Union[str, Callable] = Field( + "sqeuclidean", + description="The metric for calculating merging cost." + + "See `track_dist_metric`", + ) + + alternative_cost_factor: float = Field( + 1.05, + description="The factor to calculate the alternative costs" + + "(b,d,b',d' in Jaqaman et al 2008 NMeth)", + ) + alternative_cost_percentile: float = Field( + 90, + description="The percentile to calculate the alternative costs" + + "(b,d,b',d' in Jaqaman et al 2008 NMeth)", + ) + alternative_cost_percentile_interpolation: str = Field( + "lower", + description="The percentile interpolation to calculate the alternative costs" + + "(b,d,b',d' in Jaqaman et al 2008 NMeth)." + + "See `numpy.percentile` for accepted values.", + ) + + track_cost_cutoff: float = Field( + 15**2, + description="The cost cutoff for the connected points in the track." + + "For default cases with `dist_metric='sqeuclidean'`," + + "this value should be squared maximum distance.", + ) + track_start_cost: Optional[float] = Field( + None, # b in Jaqaman et al 2008 NMeth. + description="The cost for starting the track (b in Jaqaman et al 2008 NMeth)," + + "if None, automatically estimated", + ) + track_end_cost: Optional[float] = Field( + None, # b in Jaqaman et al 2008 NMeth. + description="The cost for ending the track (b in Jaqaman et al 2008 NMeth)," + + "if None, automatically estimated", + ) + + gap_closing_cost_cutoff: Union[Literal[False], float] = Field( + 15**2, + description="The cost cutoff for gap closing." 
+ + "For default cases with `dist_metric='sqeuclidean'`," + + "this value should be squared maximum distance." + + "If False, no gap closing is allowed.", + ) + + gap_closing_max_frame_count: int = Field( + 2, description="The maximum frame gaps, by default 2." + ) + + splitting_cost_cutoff: Union[Literal[False], float] = Field( + False, + description="The cost cutoff for splitting." + + "See `gap_closing_cost_cutoff`." + + "If False, no splitting is allowed.", + ) + + no_splitting_cost: Optional[float] = Field( + None, # d' in Jaqaman et al 2008 NMeth. + description="The cost to reject splitting, if None, automatically estimated.", + ) + + merging_cost_cutoff: Union[Literal[False], float] = Field( + False, + description="The cost cutoff for merging." + + "See `gap_closing_cost_cutoff`." + + "If False, no merging is allowed.", + ) + + no_merging_cost: Optional[float] = Field( + None, # d' in Jaqaman et al 2008 NMeth. + description="The cost to reject merging, if None, automatically estimated.", + ) + + def _link_frames(self, coords) -> nx.Graph: + """Link particles between frames according to the cost function + + Args: + coords (_type_): _description_ + + Returns: + nx.Graph: _description_ + """ + # initialize tree + track_tree = nx.Graph() + for frame, coord in enumerate(coords): + for j in range(coord.shape[0]): + track_tree.add_node((frame, j)) + + # linking between frames + for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): + dist_matrix = cdist(coord1, coord2, metric=self.track_dist_metric) + ind = np.where(dist_matrix < self.track_cost_cutoff) + dist_matrix = coo_matrix_builder( + dist_matrix.shape, + row=ind[0], + col=ind[1], + data=dist_matrix[(*ind,)], + dtype=dist_matrix.dtype, + ) + cost_matrix = build_frame_cost_matrix( + dist_matrix, + track_start_cost=self.track_start_cost, + track_end_cost=self.track_end_cost, + ) + _, xs, _ = lap_optimization(cost_matrix) + + count1 = dist_matrix.shape[0] + count2 = dist_matrix.shape[1] + 
connections = [(i, xs[i]) for i in range(count1) if xs[i] < count2] + # track_start=[i for i in range(count1) if xs[i]>count2] + # track_end=[i for i in range(count2) if ys[i]>count1] + for connection in connections: + track_tree.add_edge((frame, connection[0]), (frame + 1, connection[1])) + return track_tree + + def _get_gap_closing_matrix(self, segments_df): + return _get_segment_end_connecting_matrix( + segments_df, + self.gap_closing_max_frame_count, + self.track_dist_metric, + self.gap_closing_cost_cutoff, + ) - splitting_dist_metric : str or Callable, optional - The metric for calculating splitting cost. See `track_dist_metric`. + def _link_gap_split_merge_from_matrix( + self, + segments_df, + track_tree, + gap_closing_dist_matrix, + splitting_dist_matrix, + merging_dist_matrix, + splitting_all_candidates, + merging_all_candidates, + ): + cost_matrix = build_segment_cost_matrix( + gap_closing_dist_matrix, + splitting_dist_matrix, + merging_dist_matrix, + self.track_start_cost, + self.track_end_cost, + self.no_splitting_cost, + self.no_merging_cost, + self.alternative_cost_factor, + self.alternative_cost_percentile, + self.alternative_cost_percentile_interpolation, + ) - merging_dist_metric : str or Callable, optional - The metric for calculating merging cost. See `track_dist_metric`. + if not cost_matrix is None: + _, xs, ys = lap_optimization(cost_matrix) - alternative_cost_factor : Float, optional - The factor to calculate the alternative costs - (b,d,b',d' in Jaqaman et al 2008 NMeth), by default 1.05. + M = gap_closing_dist_matrix.shape[0] + N1 = splitting_dist_matrix.shape[1] + N2 = merging_dist_matrix.shape[1] - alternative_cost_percentile : Float, optional - The percentile to calculate the alternative costs - (b,d,b',d' in Jaqaman et al 2008 NMeth), by default 90. 
+ for ind, row in segments_df.iterrows(): + col_ind = xs[ind] + first_frame_index = (row["first_frame"], row["first_index"]) + last_frame_index = (row["last_frame"], row["last_index"]) + if col_ind < M: + target_frame_index = tuple( + segments_df.loc[col_ind, ["first_frame", "first_index"]] + ) + track_tree.add_edge(last_frame_index, target_frame_index) + elif col_ind < M + N2: + track_tree.add_edge( + last_frame_index, + tuple(merging_all_candidates[col_ind - M]), + ) + + row_ind = ys[ind] + if M <= row_ind and row_ind < M + N1: + track_tree.add_edge( + first_frame_index, + tuple(splitting_all_candidates[row_ind - M]), + ) - alternative_cost_percentile_interpolation : str, optional - The percentile interpolation to calculate the alternative costs - (b,d,b',d' in Jaqaman et al 2008 NMeth), by default "lower". - See `numpy.percentile` for accepted values. + return track_tree - track_cost_cutoff : Float, optional - The cost cutoff for the connected points in the track. - For default cases with `dist_metric="sqeuclidean"`, - this value should be squared maximum distance. - By default 15**2. + @abstractmethod + def _predict_gap_split_merge(self, coords, track_tree): + ... - track_start_cost : Float or None, optional - The cost for starting the track (b in Jaqaman et al 2008 NMeth), - by default None (automatically estimated). + def predict(self, coords) -> nx.Graph: + """Predict the tracking graph from coordinates - track_end_cost : Float or None, optional - The cost for ending the track (d in Jaqaman et al 2008 NMeth), - by default None (automatically estimated). + Args: + coords : Sequence[FloatArray] + The list of coordinates of point for each frame. + The array index means (sample, dimension). - gap_closing_cost_cutoff : Float or False, optional - The cost cutoff for gap closing. - For default cases with `dist_metric="sqeuclidean"`, - this value should be squared maximum distance. - If False, no splitting is allowed. - By default 15**2. 
- gap_closing_max_frame_count : Int - The maximum frame gaps, by default 2. + Raises: + ValueError: raised for invalid coordinate formats. - splitting_cost_cutoff : Float or False, optional - The cost cutoff for the splitting points. - See `gap_closing_cost_cutoff`. - If False, no splitting is allowed. - By default False. + Returns: + nx.Graph: The graph for the tracks, whose nodes are (frame, index). + """ - no_splitting_cost : Float or None, optional - The cost to reject splitting, by default None (automatically estimated). + if any(list(map(lambda coord: coord.ndim != 2, coords))): + raise ValueError("the elements in coords must be 2-dim.") + coord_dim = coords[0].shape[1] + if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))): + raise ValueError("the second dimension in coords must have the same size") - merging_cost_cutoff : Float or False, optional - The cost cutoff for the merging points. - See `gap_closing_cost_cutoff`. - If False, no merging is allowed. - By default False. + ####### Particle-particle tracking ####### + track_tree = self._link_frames(coords) + track_tree = self._predict_gap_split_merge(coords, track_tree) + return track_tree - no_merging_cost : Float or None, optional - The cost to reject meging, by default None (automatically estimated). - Returns - ------- - tracks : networkx.Graph - The graph for the tracks, whose nodes are (frame, index). +class LapTrack(LapTrackBase): + def _predict_gap_split_merge(self, coords, track_tree): + """one-step fitting, as TrackMate and K. Jaqaman et al., Nat Methods 5, 695 (2008). 
- """ - if any(list(map(lambda coord: coord.ndim != 2, coords))): - raise ValueError("the elements in coords must be 2-dim.") - coord_dim = coords[0].shape[1] - if any(list(map(lambda coord: coord.shape[1] != coord_dim, coords))): - raise ValueError("the second dimension in coords must have the same size") - - # initialize tree - track_tree = nx.Graph() - for frame, coord in enumerate(coords): - for j in range(coord.shape[0]): - track_tree.add_node((frame, j)) - - # linking between frames - for frame, (coord1, coord2) in enumerate(zip(coords[:-1], coords[1:])): - dist_matrix = cdist(coord1, coord2, metric=track_dist_metric) - ind = np.where(dist_matrix < track_cost_cutoff) - dist_matrix = coo_matrix_builder( - dist_matrix.shape, - row=ind[0], - col=ind[1], - data=dist_matrix[(*ind,)], - dtype=dist_matrix.dtype, - ) - cost_matrix = build_frame_cost_matrix( - dist_matrix, - track_start_cost=track_start_cost, - track_end_cost=track_end_cost, - ) - _, xs, _ = lap_optimization(cost_matrix) - - count1 = dist_matrix.shape[0] - count2 = dist_matrix.shape[1] - connections = [(i, xs[i]) for i in range(count1) if xs[i] < count2] - # track_start=[i for i in range(count1) if xs[i]>count2] - # track_end=[i for i in range(count2) if ys[i]>count1] - for connection in connections: - track_tree.add_edge((frame, connection[0]), (frame + 1, connection[1])) - - if gap_closing_cost_cutoff or splitting_cost_cutoff or merging_cost_cutoff: - # linking between tracks - segments = list(nx.connected_components(track_tree)) - N_segments = len(segments) - first_nodes = np.array( - list(map(lambda segment: min(segment, key=lambda node: node[0]), segments)) - ) - last_nodes = np.array( - list(map(lambda segment: max(segment, key=lambda node: node[0]), segments)) - ) - segments_df = pd.DataFrame( - { - "segment": segments, - "first_frame": first_nodes[:, 0], - "first_index": first_nodes[:, 1], - "last_frame": last_nodes[:, 0], - "last_index": last_nodes[:, 1], - } - ).reset_index() - - for prefix 
in ["first", "last"]: - segments_df[f"{prefix}_frame_coords"] = segments_df.apply( - lambda row: coords[row[f"{prefix}_frame"]][row[f"{prefix}_index"]], - axis=1, - ) + Args: + coords : Sequence[FloatArray] + The list of coordinates of point for each frame. + The array index means (sample, dimension). + track_tree : nx.Graph + the track tree + + Returns: + track_tree : nx.Graph + the updated track tree + """ + if ( + self.gap_closing_cost_cutoff + or self.splitting_cost_cutoff + or self.merging_cost_cutoff + ): + segments_df = _get_segment_df(coords, track_tree) - # compute candidate for gap closing - if gap_closing_cost_cutoff: + # compute candidate for gap closing + segments_df, gap_closing_dist_matrix = self._get_gap_closing_matrix( + segments_df + ) - def to_gap_closing_candidates(row): - target_coord = row["last_frame_coords"] - frame_diff = segments_df["first_frame"] - row["last_frame"] - indices = (1 <= frame_diff) & ( - frame_diff <= gap_closing_max_frame_count + middle_points: Dict = {} + dist_matrices: Dict = {} + + # compute candidate for splitting and merging + for prefix, cutoff, dist_metric in zip( + ["first", "last"], + [self.splitting_cost_cutoff, self.merging_cost_cutoff], + [self.splitting_dist_metric, self.merging_dist_metric], + ): + ( + segments_df, + dist_matrices[prefix], + middle_points[prefix], + ) = _get_splitting_merging_candidates( + segments_df, coords, cutoff, prefix, dist_metric ) - df = segments_df[indices] - # note: can use KDTree if metric is distance, - # but might not be appropriate for general metrics - # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa - # TrackMate also uses this (trivial) implementation. 
- if len(df) > 0: - target_dist_matrix = cdist( - [target_coord], - np.stack(df["first_frame_coords"].values), - metric=track_dist_metric, - ) - assert target_dist_matrix.shape[0] == 1 - indices2 = np.where( - target_dist_matrix[0] < gap_closing_cost_cutoff - )[0] - return df.index[indices2].values, target_dist_matrix[0][indices2] - else: - return [], [] - - segments_df["gap_closing_candidates"] = segments_df.apply( - to_gap_closing_candidates, axis=1 + + splitting_dist_matrix = dist_matrices["first"] + merging_dist_matrix = dist_matrices["last"] + splitting_all_candidates = middle_points["first"] + merging_all_candidates = middle_points["last"] + track_tree = self._link_gap_split_merge_from_matrix( + segments_df, + track_tree, + gap_closing_dist_matrix, + splitting_dist_matrix, + merging_dist_matrix, + splitting_all_candidates, + merging_all_candidates, ) - else: - segments_df["gap_closing_candidates"] = [[]] * len(segments_df) - gap_closing_dist_matrix = coo_matrix_builder( - (N_segments, N_segments), dtype=np.float32 + return track_tree + + +class LapTrackMulti(LapTrackBase): + segment_connecting_metric: Union[str, Callable] = Field( + "sqeuclidean", + description="The metric for calculating cost to connect segment ends." + + "See `track_dist_metric`.", + ) + segment_connecting_cost_cutoff: float = Field( + 15**2, + description="The cost cutoff for splitting." + "See `gap_closing_cost_cutoff`.", + ) + + remove_no_split_merge_links: bool = Field( + False, + description="if True, remove segment connections if splitting did not happen.", + ) + + def _get_segment_connecting_matrix(self, segments_df): + return _get_segment_end_connecting_matrix( + segments_df, + 1, # only arrow 1-frame difference + self.segment_connecting_metric, + self.segment_connecting_cost_cutoff, ) - for ind, row in segments_df.iterrows(): - candidate_inds = row["gap_closing_candidates"][0] - candidate_costs = row["gap_closing_candidates"][1] - # row ... track end, col ... 
track start - gap_closing_dist_matrix[ - (int(cast(Int, ind)), candidate_inds) - ] = candidate_costs - - all_candidates: Dict = {} - dist_matrices: Dict = {} - # compute candidate for splitting and merging + def _predict_gap_split_merge(self, coords, track_tree): + # "multi-step" type of fitting (Y. T. Fukai (2022)) + segments_df = _get_segment_df(coords, track_tree) + + ###### gap closing step ###### + ###### split - merge step 1 ###### + + get_matrix_fns = { + "gap_closing": self._get_gap_closing_matrix, + "segment_connecting": self._get_segment_connecting_matrix, + } + + segment_connected_edges = [] + for mode, get_matrix_fn in get_matrix_fns.items(): + segments_df, gap_closing_dist_matrix = get_matrix_fn(segments_df) + cost_matrix = build_frame_cost_matrix( + gap_closing_dist_matrix, + track_start_cost=self.track_start_cost, + track_end_cost=self.track_end_cost, + ) + _, xs, _ = lap_optimization(cost_matrix) + + nrow = gap_closing_dist_matrix.shape[0] + ncol = gap_closing_dist_matrix.shape[1] + connections = [(i, xs[i]) for i in range(nrow) if xs[i] < ncol] + for connection in connections: + # connection ... 
connection segments_df.iloc[i] -> segments_df.iloc[xs[i]] + node_from = tuple( + segments_df.loc[connection[0], ["last_frame", "last_index"]] + ) + node_to = tuple( + segments_df.loc[connection[1], ["first_frame", "first_index"]] + ) + track_tree.add_edge(node_from, node_to) + if mode == "segment_connecting": + segment_connected_edges.append((node_from, node_to)) + + # regenerate segments after closing gaps + segments_df = _get_segment_df(coords, track_tree) + + ###### split - merge step 2 ###### + middle_points: Dict = {} + dist_matrices: Dict = {} for prefix, cutoff, dist_metric in zip( ["first", "last"], - [splitting_cost_cutoff, merging_cost_cutoff], - [splitting_dist_metric, merging_dist_metric], + [self.splitting_cost_cutoff, self.merging_cost_cutoff], + [self.splitting_dist_metric, self.merging_dist_metric], ): - if cutoff: - - def to_candidates(row): - target_coord = row[f"{prefix}_frame_coords"] - frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1) - # note: can use KDTree if metric is distance, - # but might not be appropriate for general metrics - # https://stackoverflow.com/questions/35459306/find-points-within-cutoff-distance-of-other-points-with-scipy # noqa - if frame < 0 or len(coords) <= frame: - return [], [] - target_dist_matrix = cdist( - [target_coord], coords[frame], metric=dist_metric + dist_metric_argnums = None + if callable(dist_metric): + try: + s = signature(dist_metric) + dist_metric_argnums = len( + [ + 0 + for p in s.parameters.values() + if p.kind == Parameter.POSITIONAL_OR_KEYWORD + or p.kind == Parameter.POSITIONAL_ONLY + ] ) - assert target_dist_matrix.shape[0] == 1 - indices = np.where(target_dist_matrix[0] < cutoff)[0] - return [(frame, index) for index in indices], target_dist_matrix[0][ - indices - ] - - segments_df[f"{prefix}_candidates"] = segments_df.apply( - to_candidates, axis=1 + except TypeError: + pass + if callable(dist_metric) and dist_metric_argnums >= 3: + logger.info("using callable dist_metric with 
more than 2 parameters") + # the dist_metric function is assumed to take + # (coordinate1, coordinate2, coordinate_sibring, connected by segment_connecting step) + segment_connected_nodes = [ + e[0 if prefix == "first" else 1] for e in segment_connected_edges + ] # find nodes connected by "segment_connect" steps + _coords = [ + [(*c, frame, ind) for ind, c in enumerate(coord_frame)] + for frame, coord_frame in enumerate(coords) + ] + assert np.all(c.shape[1] == _coords[0].shape[1] for c in _coords) + + # _coords ... (coordinate, frame, if connected by segment_connecting step) + def _dist_metric(c1, c2): + *_c1, frame1, ind1 = c1 + *_c2, frame2, ind2 = c2 + # for splitting case, check the yonger one + if not frame1 < frame2: + # swap frame1 and 2; always assume coordinate 1 is first + tmp = _c1, frame1, ind1 + _c1, frame1, ind1 = _c2, frame2, ind2 + _c2, frame2, ind2 = tmp + check_node = (frame1, ind1) if prefix == "first" else (frame2, ind2) + if dist_metric_argnums == 3: + return dist_metric( + np.array(_c1), + np.array(_c2), + check_node in segment_connected_nodes, + ) + else: + if prefix == "first": + # splitting sibring candidate + candidates = [ + (frame, ind) + for (frame, ind) in track_tree.neighbors((frame1, ind1)) + if frame > frame1 + ] + else: + # merging sibring candidate + candidates = [ + (frame, ind) + for (frame, ind) in track_tree.neighbors((frame2, ind2)) + if frame < frame2 + ] + + if len(candidates) == 0: + c_sib = None + else: + assert len(candidates) == 1 + c_sib = candidates[0] + return dist_metric( + np.array(_c1), + np.array(_c2), + c_sib, + check_node in segment_connected_nodes, + ) + + segments_df[f"{prefix}_frame_coords"] = segments_df.apply( + lambda row: ( + *row[f"{prefix}_frame_coords"], + int(row[f"{prefix}_frame"]), + int(row[f"{prefix}_index"]), + ), + axis=1, ) - else: - segments_df[f"{prefix}_candidates"] = [([], [])] * len(segments_df) - all_candidates[prefix] = np.unique( - sum( - 
segments_df[f"{prefix}_candidates"].apply(lambda x: list(x[0])), [] - ), - axis=0, - ) - - N_middle = len(all_candidates[prefix]) - dist_matrices[prefix] = coo_matrix_builder( - (N_segments, N_middle), dtype=np.float32 + else: + logger.info("using callable dist_metric with 2 parameters") + _coords = coords + _dist_metric = dist_metric + + ( + segments_df, + dist_matrices[prefix], + middle_points[prefix], + ) = _get_splitting_merging_candidates( + segments_df, _coords, cutoff, prefix, _dist_metric ) - all_candidates_dict = { - tuple(val): i for i, val in enumerate(all_candidates[prefix]) - } - for ind, row in segments_df.iterrows(): - candidate_frame_indices = row[f"{prefix}_candidates"][0] - candidate_inds = [ - all_candidates_dict[tuple(fi)] for fi in candidate_frame_indices - ] - candidate_costs = row[f"{prefix}_candidates"][1] - dist_matrices[prefix][ - (int(cast(Int, ind)), candidate_inds) - ] = candidate_costs - splitting_dist_matrix = dist_matrices["first"] merging_dist_matrix = dist_matrices["last"] - splitting_all_candidates = all_candidates["first"] - merging_all_candidates = all_candidates["last"] - cost_matrix = build_segment_cost_matrix( - gap_closing_dist_matrix, + splitting_all_candidates = middle_points["first"] + merging_all_candidates = middle_points["last"] + N_segments = len(segments_df) + track_tree = self._link_gap_split_merge_from_matrix( + segments_df, + track_tree, + coo_matrix((N_segments, N_segments), dtype=np.float32), # no gap closing splitting_dist_matrix, merging_dist_matrix, - track_start_cost, - track_end_cost, - no_splitting_cost, - no_merging_cost, - alternative_cost_factor, - alternative_cost_percentile, - alternative_cost_percentile_interpolation, + splitting_all_candidates, + merging_all_candidates, ) - if not cost_matrix is None: - _, xs, ys = lap_optimization(cost_matrix) + ###### remove segment connections if not associated with split / merge ###### - M = gap_closing_dist_matrix.shape[0] - N1 = splitting_dist_matrix.shape[1] 
- N2 = merging_dist_matrix.shape[1] + if self.remove_no_split_merge_links: + track_tree = _remove_no_split_merge_links( + track_tree.copy(), segment_connected_edges + ) + return track_tree - for ind, row in segments_df.iterrows(): - col_ind = xs[ind] - first_frame_index = (row["first_frame"], row["first_index"]) - last_frame_index = (row["last_frame"], row["last_index"]) - if col_ind < M: - target_frame_index = tuple( - segments_df.loc[col_ind, ["first_frame", "first_index"]] - ) - track_tree.add_edge(last_frame_index, target_frame_index) - elif col_ind < M + N2: - track_tree.add_edge( - last_frame_index, tuple(merging_all_candidates[col_ind - M]) - ) - row_ind = ys[ind] - if M <= row_ind and row_ind < M + N1: - track_tree.add_edge( - first_frame_index, tuple(splitting_all_candidates[row_ind - M]) - ) +def laptrack(coords: Sequence[FloatArray], **kwargs) -> nx.Graph: + """Track points by solving linear assignment problem. - return track_tree + Parameters + ---------- + coords : Sequence[FloatArray] + The list of coordinates of point for each frame. + The array index means (sample, dimension). + + **kwargs : dict + Parameters for the LapTrack initalization + + Returns + ------- + tracks : networkx.Graph + The graph for the tracks, whose nodes are (frame, index). 
+ + """ + lt = LapTrack(**kwargs) + return lt.predict(coords) diff --git a/src/laptrack/_typing_utils.py b/src/laptrack/_typing_utils.py index ca878019..477a2ee2 100644 --- a/src/laptrack/_typing_utils.py +++ b/src/laptrack/_typing_utils.py @@ -2,8 +2,8 @@ import numpy as np import numpy.typing as npt +from scipy.sparse import coo_matrix from scipy.sparse import lil_matrix -from scipy.sparse.coo import coo_matrix NumArray = npt.NDArray[Union[np.float_, np.int_]] FloatArray = npt.NDArray[np.float_] diff --git a/tests/test_tracking.py b/tests/test_tracking.py index e376aa86..204a5404 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -4,8 +4,11 @@ import networkx as nx import numpy as np import pandas as pd +import pytest +from laptrack import LapTrack from laptrack import laptrack +from laptrack import LapTrackMulti DEFAULT_PARAMS = dict( track_dist_metric="sqeuclidean", @@ -57,59 +60,117 @@ ] # type: ignore -def test_reproducing_trackmate(shared_datadir: str) -> None: - for filename_suffix, params in FILENAME_SUFFIX_PARAMS: - filename = path.join(shared_datadir, f"trackmate_tracks_{filename_suffix}") - spots_df = pd.read_csv(filename + "_spots.csv") - frame_max = spots_df["frame"].max() - coords = [] - spot_ids = [] - for i in range(frame_max): - df = spots_df[spots_df["frame"] == i] - coords.append(df[["position_x", "position_y"]].values) - spot_ids.append(df["id"].values) - track_tree = laptrack(coords, **params) # type: ignore - - spot_id_to_coord_id = {} - for i, spot_ids_frame in enumerate(spot_ids): - for j, spot_id in enumerate(spot_ids_frame): - assert not spot_id in spot_id_to_coord_id - spot_id_to_coord_id[spot_id] = (i, j) - - edges_df = pd.read_csv(filename + "_edges.csv", index_col=0) - edges_df["coord_source_id"] = edges_df["spot_source_id"].map( - spot_id_to_coord_id - ) - edges_df["coord_target_id"] = edges_df["spot_target_id"].map( - spot_id_to_coord_id +@pytest.fixture(params=FILENAME_SUFFIX_PARAMS) +def testdata(request, 
shared_datadir: str): + filename_suffix, params = request.param + filename = path.join(shared_datadir, f"trackmate_tracks_{filename_suffix}") + spots_df = pd.read_csv(filename + "_spots.csv") + frame_max = spots_df["frame"].max() + coords = [] + spot_ids = [] + for i in range(frame_max): + df = spots_df[spots_df["frame"] == i] + coords.append(df[["position_x", "position_y"]].values) + spot_ids.append(df["id"].values) + + spot_id_to_coord_id = {} + for i, spot_ids_frame in enumerate(spot_ids): + for j, spot_id in enumerate(spot_ids_frame): + assert not spot_id in spot_id_to_coord_id + spot_id_to_coord_id[spot_id] = (i, j) + + edges_df = pd.read_csv(filename + "_edges.csv", index_col=0) + edges_df["coord_source_id"] = edges_df["spot_source_id"].map(spot_id_to_coord_id) + edges_df["coord_target_id"] = edges_df["spot_target_id"].map(spot_id_to_coord_id) + valid_edges_df = edges_df[~pd.isna(edges_df["coord_target_id"])] + edges_arr = valid_edges_df[["coord_source_id", "coord_target_id"]].values + edges_set = set(list(map(tuple, (edges_arr)))) + + return params, coords, edges_set + + +def test_reproducing_trackmate(testdata) -> None: + params, coords, edges_set = testdata + lt = LapTrack(**params) + track_tree = lt.predict(coords) + assert edges_set == set(track_tree.edges) + + +def test_multi_algorithm_reproducing_trackmate(testdata) -> None: + params, coords, edges_set = testdata + lt = LapTrackMulti(**params) + track_tree = lt.predict(coords) + assert edges_set == set(track_tree.edges) + + +@pytest.fixture(params=[2, 3, 4]) +def dist_metric(request): + if request.param == 2: + return lambda c1, c2: np.linalg.norm(c1 - c2) ** 2 + elif request.param == 3: + return lambda c1, c2, _1: np.linalg.norm(c1 - c2) ** 2 + elif request.param == 4: + return lambda c1, c2, _1, _2: np.linalg.norm(c1 - c2) ** 2 + + +def test_multi_algorithm_reproducing_trackmate_lambda(testdata, dist_metric) -> None: + params, coords, edges_set = testdata + params = params.copy() + params.update( + 
dict( + track_dist_metric=lambda c1, c2: np.linalg.norm(c1 - c2) ** 2, + splitting_dist_metric=dist_metric, + merging_dist_metric=dist_metric, ) - valid_edges_df = edges_df[~pd.isna(edges_df["coord_target_id"])] - edges_arr = valid_edges_df[["coord_source_id", "coord_target_id"]].values + ) + lt = LapTrackMulti(**params) + track_tree = lt.predict(coords) + assert edges_set == set(track_tree.edges) + + +def test_multi_algorithm_reproducing_trackmate_3_arg_lambda(testdata) -> None: + params, coords, edges_set = testdata + lt = LapTrackMulti(**params) + track_tree = lt.predict(coords) + assert edges_set == set(track_tree.edges) + + +def test_multi_algorithm_reproducing_trackmate_4_arg_lambda(testdata) -> None: + params, coords, edges_set = testdata + lt = LapTrackMulti(**params) + track_tree = lt.predict(coords) + assert edges_set == set(track_tree.edges) + - assert set(list(map(tuple, (edges_arr)))) == set(track_tree.edges) +def test_laptrack_function_shortcut(testdata) -> None: + params, coords, edges_set = testdata + lt = LapTrack(**params) + track_tree1 = lt.predict(coords) + track_tree2 = laptrack(coords, **params) + assert set(track_tree1.edges) == set(track_tree2.edges) def test_tracking_zero_distance() -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[10, 10], [13, 11]])] - track_tree = laptrack( - coords, + lt = LapTrack( gap_closing_cost_cutoff=False, splitting_cost_cutoff=False, merging_cost_cutoff=False, ) # type: ignore + track_tree = lt.predict(coords) edges = track_tree.edges() assert set(edges) == set([((0, 0), (1, 0)), ((0, 1), (1, 1))]) def test_tracking_not_connected() -> None: coords = [np.array([[10, 10], [12, 11]]), np.array([[50, 50], [53, 51]])] - track_tree = laptrack( - coords, + lt = LapTrack( track_cost_cutoff=15**2, gap_closing_cost_cutoff=False, splitting_cost_cutoff=False, merging_cost_cutoff=False, ) # type: ignore + track_tree = lt.predict(coords) edges = track_tree.edges() assert set(edges) == set() @@ -121,12 +182,12 @@ 
def test_gap_closing(shared_datadir: str) -> None: allow_pickle=True, ) ) - track_tree = laptrack( - coords, + lt = LapTrack( track_cost_cutoff=15**2, splitting_cost_cutoff=False, merging_cost_cutoff=False, ) # type: ignore + track_tree = lt.predict(coords) for track in nx.connected_components(track_tree): frames, _ = zip(*track) assert len(set(frames)) == len(frames) diff --git a/tests/test_tracking_routines.py b/tests/test_tracking_routines.py new file mode 100644 index 00000000..43b2c709 --- /dev/null +++ b/tests/test_tracking_routines.py @@ -0,0 +1,92 @@ +import networkx as nx +import numpy as np + +from laptrack._tracking import _get_segment_df +from laptrack._tracking import _remove_no_split_merge_links + + +def test_reproducing_trackmate() -> None: + test_tree = nx.Graph() + test_tree.add_edges_from( + [ + ((0, 0), (1, 0)), + ((1, 0), (2, 0)), + ((2, 0), (3, 0)), + ((3, 0), (4, 0)), + ((1, 0), (2, 1)), + ((2, 1), (3, 1)), + ] + ) + segment_connected_edges = [ + ((1, 0), (2, 0)), + ((2, 0), (3, 0)), + ((2, 1), (3, 1)), + ] + removed_edges = [ + ((2, 0), (3, 0)), + ((2, 1), (3, 1)), + ] + res_tree = _remove_no_split_merge_links(test_tree.copy(), segment_connected_edges) + assert set(test_tree.edges) - set(res_tree.edges) == set(removed_edges) + + +def test_get_segment_df() -> None: + test_tree = nx.Graph() + test_tree.add_edges_from( + [ + ((0, 0), (1, 0)), + ((1, 0), (2, 0)), + ((2, 0), (3, 0)), + ((3, 0), (4, 0)), + ((2, 1), (3, 1)), + ((2, 2), (3, 2)), + ((3, 2), (4, 2)), + ] + ) + + test_coords = [ + np.array([[1.0, 1.0]]), + np.array([[2.0, 2.0], [2.1, 2.1]]), + np.array([[3.0, 3.0], [3.1, 3.1], [3.2, 3.2]]), + np.array([[4.0, 4.0], [4.1, 4.1], [4.2, 4.2]]), + np.array([[5.0, 5.0], [5.1, 5.1], [5.2, 5.2]]), + ] + + expected_segment_dfs = [ + ({(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)}, 0, 0, 4, 0, (1.0, 1.0), (5.0, 5.0)), + ( + { + (2, 1), + (3, 1), + }, + 2, + 1, + 3, + 1, + (3.1, 3.1), + (4.1, 4.1), + ), + ({(2, 2), (3, 2), (4, 2)}, 2, 2, 4, 2, (3.2, 
3.2), (5.2, 5.2)), + ] + + segment_df = _get_segment_df(test_coords, test_tree) + + segment_df_set = [] + for i, row in segment_df.iterrows(): + segment_df_set.append( + ( + set(row["segment"]), + row["first_frame"], + row["first_index"], + row["last_frame"], + row["last_index"], + tuple(row["first_frame_coords"]), + tuple(row["last_frame_coords"]), + ) + ) + for s in segment_df_set: + matched = False + for s1 in expected_segment_dfs: + if s1 == s: + matched = True + assert matched