From 6ce3beefb504e08b4e6d3aa61a6c5b8cdfc98c97 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 12 Feb 2024 09:01:48 -0500 Subject: [PATCH] updating stitching to improve speed (#845) --- cellpose/io.py | 20 +++++++++++++++----- cellpose/utils.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/cellpose/io.py b/cellpose/io.py index d4d1ec6e..3322e34e 100644 --- a/cellpose/io.py +++ b/cellpose/io.py @@ -196,6 +196,7 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False): igood &= imfile[-len(imf):]==imf if igood: imn.append(im) + image_names = imn # remove duplicates @@ -240,9 +241,15 @@ def get_label_files(image_names, mask_filter, imf=None): #elif os.path.exists(label_names[0] + '_seg.npy'): # io_logger.info('labels found as _seg.npy files, converting to tif') else: - raise ValueError('labels not provided with correct --mask_filter') + if not flow_names: + raise ValueError('labels not provided with correct --mask_filter') + else: + label_names = None if not all([os.path.exists(label) for label in label_names]): - raise ValueError('labels not provided for all images in train and/or test set') + if not flow_names: + raise ValueError('labels not provided for all images in train and/or test set') + else: + label_names = None return label_names, flow_names @@ -250,7 +257,7 @@ def get_label_files(image_names, mask_filter, imf=None): def load_images_labels(tdir, mask_filter='_masks', image_filter=None, look_one_level_down=False, unet=False): image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down) nimg = len(image_names) - + # training data label_names, flow_names = get_label_files(image_names, mask_filter, imf=image_filter) @@ -258,11 +265,14 @@ def load_images_labels(tdir, mask_filter='_masks', image_filter=None, look_one_l labels = [] k = 0 for n in range(nimg): - if os.path.isfile(label_names[n]): + if os.path.isfile(label_names[n]) or os.path.isfile(flow_names[0]): + print(image_names[n]) image = imread(image_names[n]) - label = imread(label_names[n]) + if label_names is not None: + label = imread(label_names[n]) if not unet: if flow_names is not None and not unet: + print(flow_names[n]) flow = imread(flow_names[n]) if flow.shape[0]<4: label = np.concatenate((label[np.newaxis,:,:], flow), axis=0) diff --git a/cellpose/utils.py b/cellpose/utils.py index 86090807..5ac41279 100644 --- a/cellpose/utils.py +++ b/cellpose/utils.py @@ -3,7 +3,7 @@ """ import logging import os, warnings, time, tempfile, datetime, pathlib, shutil, subprocess -from tqdm import tqdm +from tqdm import tqdm, trange from urllib.request import urlopen from urllib.parse import urlparse import cv2 @@ -403,7 +403,7 @@ def stitch3D(masks, stitch_threshold=0.25): mmax = masks[0].max() empty = 0 - for i in range(len(masks)-1): + for i in trange(len(masks)-1): iou = metrics._intersection_over_union(masks[i+1], masks[i])[1:,1:] if not iou.size and empty == 0: masks[i+1] = masks[i+1]