Skip to content

Commit

Permalink
Merge pull request #334 from yfukai/parallel_refactor
Browse files Browse the repository at this point in the history
parallel refactoring for ray
  • Loading branch information
yfukai authored Jun 16, 2023
2 parents 905b09a + 8289bec commit d3c817d
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/laptrack/_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import warnings
from enum import Enum
from functools import partial
from typing import Callable
from typing import cast
from typing import Dict
Expand Down Expand Up @@ -313,7 +314,7 @@ def _get_gap_closing_matrix(
"""
if self.gap_closing_cost_cutoff:

def to_gap_closing_candidates(row):
def to_gap_closing_candidates(row, segments_df):
# if the index is in force_end_indices, do not add to gap closing candidates
if (row["last_frame"], row["last_index"]) in force_end_nodes:
return [], []
Expand Down Expand Up @@ -357,7 +358,7 @@ def to_gap_closing_candidates(row):

if self.parallel_backend == ParallelBackend.serial:
segments_df["gap_closing_candidates"] = segments_df.apply(
to_gap_closing_candidates, axis=1
partial(to_gap_closing_candidates, segments_df=segments_df), axis=1
)
elif self.parallel_backend == ParallelBackend.ray:
try:
Expand All @@ -367,7 +368,11 @@ def to_gap_closing_candidates(row):
"Please install `ray` to use `ParallelBackend.ray`."
)
remote_func = ray.remote(to_gap_closing_candidates)
res = [remote_func.remote(row) for _, row in segments_df.iterrows()]
segments_df_id = ray.put(segments_df)
res = [
remote_func.remote(row, segments_df_id)
for _, row in segments_df.iterrows()
]
segments_df["gap_closing_candidates"] = ray.get(res)
else:
raise ValueError(f"Unknown parallel_backend {self.parallel_backend}. ")
Expand Down Expand Up @@ -401,7 +406,7 @@ def _get_splitting_merging_candidates(
):
if cutoff:

def to_candidates(row):
def to_candidates(row, coords):
# if the prefix is first, this means the row is the track start, and the target is the track end
other_frame = row[f"{prefix}_frame"] + (-1 if prefix == "first" else 1)
target_coord = row[f"{prefix}_frame_coords"]
Expand Down Expand Up @@ -444,7 +449,7 @@ def to_candidates(row):

if self.parallel_backend == ParallelBackend.serial:
segments_df[f"{prefix}_candidates"] = segments_df.apply(
to_candidates, axis=1
partial(to_candidates, coords=coords), axis=1
)
elif self.parallel_backend == ParallelBackend.ray:
try:
Expand All @@ -454,7 +459,11 @@ def to_candidates(row):
"Please install `ray` to use `ParallelBackend.ray`."
)
remote_func = ray.remote(to_candidates)
res = [remote_func.remote(row) for _, row in segments_df.iterrows()]
coords_id = ray.put(coords)
res = [
remote_func.remote(row, coords_id)
for _, row in segments_df.iterrows()
]
segments_df[f"{prefix}_candidates"] = ray.get(res)
else:
raise ValueError(f"Unknown parallel_backend {self.parallel_backend}. ")
Expand Down

0 comments on commit d3c817d

Please sign in to comment.