Skip to content

Commit

Permalink
Merge pull request #138 from MannLabs/dev_mp
Browse files Browse the repository at this point in the history
Fix Out of Bounds issue upon multiple multiprocessing worker runs
  • Loading branch information
sophiamaedler authored Jan 18, 2025
2 parents 622a45f + 43f4a98 commit fffa9f4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/scportrait/pipeline/segmentation/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,6 @@ def _resolve_sharding(self, sharding_plan):
local_hf = h5py.File(local_output, "r")
local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)[:]

print(type(local_hdf_labels))
shifted_map, edge_labels = shift_labels(
local_hdf_labels,
class_id_shift,
Expand Down Expand Up @@ -902,8 +901,9 @@ def _resolve_sharding(self, sharding_plan):
if not self.deep_debug:
self._cleanup_shards(sharding_plan)

def _initializer_function(self, gpu_id_list):
def _initializer_function(self, gpu_id_list, n_processes):
current_process().gpu_id_list = gpu_id_list
current_process().n_processes = n_processes

def _perform_segmentation(self, shard_list):
# get GPU status
Expand All @@ -921,7 +921,7 @@ def _perform_segmentation(self, shard_list):
with mp.get_context(self.context).Pool(
processes=self.n_processes,
initializer=self._initializer_function,
initargs=[self.gpu_id_list],
initargs=[self.gpu_id_list, self.n_processes],
) as pool:
list(
tqdm(
Expand Down
4 changes: 4 additions & 0 deletions src/scportrait/pipeline/segmentation/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from skimage.filters import median
from skimage.morphology import binary_erosion, dilation, disk, erosion
from skimage.segmentation import watershed
import _multiprocessing

from scportrait.pipeline._utils.segmentation import (
contact_filter,
Expand Down Expand Up @@ -1353,6 +1354,9 @@ def _check_gpu_status(self):
gpu_id_list = current.gpu_id_list
cpu_id = int(cpu_name[cpu_name.find("-") + 1 :]) - 1

if cpu_id >= len(gpu_id_list):
cpu_id = cpu_id%current.n_processes

# track gpu_id and update GPU status
self.gpu_id = gpu_id_list[cpu_id]
self.status = "multi_GPU"
Expand Down

0 comments on commit fffa9f4

Please sign in to comment.