diff --git a/src/laptrack/data_conversion.py b/src/laptrack/data_conversion.py index 0dabdca7..c08a352d 100644 --- a/src/laptrack/data_conversion.py +++ b/src/laptrack/data_conversion.py @@ -5,8 +5,11 @@ import numpy as np import pandas as pd +from ._typing_utils import Int from ._typing_utils import NumArray +IntTuple = Tuple[Int, Int] + def convert_dataframe_to_coords( df: pd.DataFrame, @@ -88,8 +91,8 @@ def convert_tree_to_dataframe( # tree.nodes[(frame, index)]["tree_id"] = track_id tree2 = tree.copy() - splits = [] - merges = [] + splits: List[Tuple[IntTuple, List[IntTuple]]] = [] + merges: List[Tuple[IntTuple, List[IntTuple]]] = [] for node in tree.nodes: frame0, _index0 = node neighbors = list(tree.neighbors(node)) @@ -101,13 +104,13 @@ def convert_tree_to_dataframe( if tree2.has_edge(node, child): tree2.remove_edge(node, child) if node not in [p[0] for p in splits]: - splits.append([node, children]) + splits.append((node, children)) if len(parents) > 1: for parent in parents: if tree2.has_edge(node, parent): tree2.remove_edge(node, parent) if node not in [p[0] for p in merges]: - merges.append([node, parents]) + merges.append((node, parents)) connected_components = list(nx.connected_components(tree2)) for track_id, nodes in enumerate(connected_components):