diff --git a/setup.cfg b/setup.cfg index b990699a..f977da71 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,9 +23,12 @@ install_requires = daisy @ git+https://github.com/funkelab/daisy funlib.learn.torch @ git+https://github.com/funkelab/funlib.learn.torch funlib.math @ git+https://github.com/funkelab/funlib.math + funlib.persistence @ git+https://github.com/funkelab/funlib.persistence + funlib.geometry @ git+https://github.com/funkelab/funlib.geometry funlib.show.neuroglancer @ git+https://github.com/funkelab/funlib.show.neuroglancer + funlib.evaluate @ git+https://github.com/funkelab/funlib.evaluate gunpowder @ git+https://github.com/rhoadesScholar/gunpowder@raygun - lsd @ git+https://github.com/funkelab/lsd@restructure + lsds @ git+https://github.com/funkelab/lsd waterz @ git+https://github.com/funkey/waterz matplotlib neuroglancer @@ -44,9 +47,9 @@ install_requires = jupyter ipywidgets webknossos - funlib.evaluate @ git+https://github.com/funkelab/funlib.evaluate pytest seaborn + logging [options.packages.find] where=src diff --git a/src/raygun/io/BaseDataPipe.py b/src/raygun/io/BaseDataPipe.py index 97ed2876..1e2f7845 100644 --- a/src/raygun/io/BaseDataPipe.py +++ b/src/raygun/io/BaseDataPipe.py @@ -28,9 +28,10 @@ def get_source(self, path, src_names, src_specs=None): def prenet_pipe(self, mode: str = "train"): # Make pre-net datapipe - prenet_pipe = self.source + gp.RandomLocation() + prenet_pipe = self.source if mode == "train": sections = [ + gp.RandomLocation(), "reject", "resample", "preprocess", @@ -41,7 +42,7 @@ def prenet_pipe(self, mode: str = "train"): elif mode == "predict": sections = ["reject", "resample", "preprocess", "unsqueeze", "stack"] elif mode == "test": - sections = ["reject", "resample", "preprocess", "unsqueeze", gp.Stack(1)] + sections = [gp.RandomLocation(), "reject", "resample", "preprocess", "unsqueeze", gp.Stack(1)] else: raise ValueError(f"mode={mode} not implemented.") diff --git a/src/raygun/webknossos_utils/wkw_seg_to_zarr.py b/src/raygun/webknossos_utils/wkw_seg_to_zarr.py index f6a12441..57d335ec 100644 --- a/src/raygun/webknossos_utils/wkw_seg_to_zarr.py +++ b/src/raygun/webknossos_utils/wkw_seg_to_zarr.py @@ -8,8 +8,15 @@ import tempfile from glob import glob import os +import logging +from funlib.persistence import open_ds, prepare_ds +from funlib.geometry import Roi, Coordinate +import numpy as np +from skimage.draw import line_nd +logger = logging.getLogger(__name__) + def download_wk_skeleton( annotation_ID, save_path, @@ -52,6 +59,90 @@ def download_wk_skeleton( return zip_path +def parse_skeleton(zip_path) -> dict: + fin = zip_path + if not fin.endswith(".zip"): + try: + fin = get_updated_skeleton(zip_path) + assert fin.endswith(".zip"), "Skeleton zip file not found." + except: + assert False, "CATMAID NOT IMPLEMENTED" + + wk_skels = wk.skeleton.Skeleton.load(fin) + # return wk_skels + skel_coor = {} + for tree in wk_skels.trees: + skel_coor[tree.id] = [] + for start, end in tree.edges.keys(): + start_pos = start.position.to_np() + end_pos = end.position.to_np() + skel_coor[tree.id].append([start_pos, end_pos]) + + return skel_coor + + +def get_updated_skeleton(zip_path) -> str: + if not os.path.exists(zip_path): + path = os.path.dirname(os.path.realpath(zip_path)) + search_path = os.path.join(path, "skeletons/*") + files = glob(search_path) + if len(files) == 0: + skel_file = download_wk_skeleton() + else: + skel_file = max(files, key=os.path.getctime) + skel_file = os.path.abspath(skel_file) + + return skel_file + +def rasterize_skeleton(zip_path:str, + raw_file:str, + raw_ds:str) -> np.ndarray: + + logger.info(f"Rasterizing skeleton...") + + skel_coor = parse_skeleton(zip_path) + + # Initialize rasterized skeleton image + raw = open_ds(raw_file, raw_ds) + + dataset_shape = raw.data.shape + print(dataset_shape) + voxel_size = raw.voxel_size + offset = raw.roi.begin # unhardcode for nonzero offset + image = np.zeros(dataset_shape, dtype=np.uint8) + + def adjust(coor): + ds_under = [x-1 for x in dataset_shape] + return np.min([coor - offset, ds_under], 0) + + print("adjusting . . .") + for id, tree in skel_coor.items(): + # iterates through ever node and assigns id to {image} + for start, end in tree: + line = line_nd(adjust(start), adjust(end)) + image[line] = id + + + # Save GT rasterization #TODO: implement daisy blockwise option + total_roi = Roi( + Coordinate(offset) * Coordinate(voxel_size), + Coordinate(dataset_shape) * Coordinate(voxel_size), + ) + + print("saving . . .") + out_ds = prepare_ds( + raw_file, + "volumes/training_rasters", + total_roi, + voxel_size, + image.dtype, + delete=True, + ) + out_ds[out_ds.roi] = image + + return image + + def get_wk_mask( annotation_ID, save_path, # TODO: Add mkdtemp() as default