Commit
Merge branch 'main' into docs
brianreicher authored Oct 12, 2023
2 parents 9d914f8 + 57da521 commit dd42ad3
Showing 3 changed files with 99 additions and 4 deletions.
7 changes: 5 additions & 2 deletions setup.cfg
@@ -23,9 +23,12 @@ install_requires =
daisy @ git+https://github.com/funkelab/daisy
funlib.learn.torch @ git+https://github.com/funkelab/funlib.learn.torch
funlib.math @ git+https://github.com/funkelab/funlib.math
funlib.persistence @ git+https://github.com/funkelab/funlib.persistence
funlib.geometry @ git+https://github.com/funkelab/funlib.geometry
funlib.show.neuroglancer @ git+https://github.com/funkelab/funlib.show.neuroglancer
funlib.evaluate @ git+https://github.com/funkelab/funlib.evaluate
gunpowder @ git+https://github.com/rhoadesScholar/gunpowder@raygun
lsd @ git+https://github.com/funkelab/lsd@restructure
lsds @ git+https://github.com/funkelab/lsd
waterz @ git+https://github.com/funkey/waterz
matplotlib
neuroglancer
@@ -44,9 +47,9 @@ install_requires =
jupyter
ipywidgets
webknossos
funlib.evaluate @ git+https://github.com/funkelab/funlib.evaluate
pytest
seaborn
logging

[options.packages.find]
where=src
5 changes: 3 additions & 2 deletions src/raygun/io/BaseDataPipe.py
@@ -28,9 +28,10 @@ def get_source(self, path, src_names, src_specs=None):

def prenet_pipe(self, mode: str = "train"):
# Make pre-net datapipe
prenet_pipe = self.source + gp.RandomLocation()
prenet_pipe = self.source
if mode == "train":
sections = [
gp.RandomLocation(),
"reject",
"resample",
"preprocess",
@@ -41,7 +42,7 @@ def prenet_pipe(self, mode: str = "train"):
elif mode == "predict":
sections = ["reject", "resample", "preprocess", "unsqueeze", "stack"]
elif mode == "test":
sections = ["reject", "resample", "preprocess", "unsqueeze", gp.Stack(1)]
sections = [gp.RandomLocation(), "reject", "resample", "preprocess", "unsqueeze", gp.Stack(1)]
else:
raise ValueError(f"mode={mode} not implemented.")

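For context on this change, the mode-dependent section lists above are composed onto self.source (built via get_source()) to form a gunpowder pipeline; with this diff, gp.RandomLocation() is only added in the train and test modes instead of being attached to the source unconditionally. A minimal sketch of that composition pattern, assuming a hypothetical zarr source and array key rather than the repository's actual configuration, might look like this:

import gunpowder as gp

# Hypothetical source and key for illustration only; the real datapipe builds
# its source via get_source() from its own configuration.
raw = gp.ArrayKey("RAW")
source = gp.ZarrSource("sample.zarr", {raw: "volumes/raw"})

# Train-style composition: sample a random location, then batch with Stack.
pipeline = source + gp.RandomLocation() + gp.Stack(1)

request = gp.BatchRequest()
request.add(raw, gp.Coordinate((64, 64, 64)))  # ROI size in world units

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
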
91 changes: 91 additions & 0 deletions src/raygun/webknossos_utils/wkw_seg_to_zarr.py
@@ -8,8 +8,15 @@
import tempfile
from glob import glob
import os
import logging
from funlib.persistence import open_ds, prepare_ds
from funlib.geometry import Roi, Coordinate
import numpy as np
from skimage.draw import line_nd


logger = logging.getLogger(__name__)

def download_wk_skeleton(
annotation_ID,
save_path,
@@ -52,6 +59,90 @@ def download_wk_skeleton(
return zip_path


def parse_skeleton(zip_path) -> dict:
fin = zip_path
if not fin.endswith(".zip"):
try:
fin = get_updated_skeleton(zip_path)
assert fin.endswith(".zip"), "Skeleton zip file not found."
        except Exception:
            raise NotImplementedError("CATMAID skeletons are not implemented.")

wk_skels = wk.skeleton.Skeleton.load(fin)
# return wk_skels
skel_coor = {}
for tree in wk_skels.trees:
skel_coor[tree.id] = []
for start, end in tree.edges.keys():
start_pos = start.position.to_np()
end_pos = end.position.to_np()
skel_coor[tree.id].append([start_pos, end_pos])

return skel_coor
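
A hypothetical usage sketch for the parser above (the annotation path is an assumption for illustration, not a file shipped with the repository):

skel_coor = parse_skeleton("annotations/my_annotation.zip")
# skel_coor maps each webKnossos tree id to a list of [start, end] node
# positions (numpy arrays, in voxel coordinates).
for tree_id, edges in skel_coor.items():
    for start_pos, end_pos in edges:
        print(tree_id, start_pos, end_pos)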


def get_updated_skeleton(zip_path) -> str:
    # If the given zip already exists, use it; otherwise fall back to the most
    # recent download in the sibling "skeletons/" folder, downloading one if
    # none exist.
    if os.path.exists(zip_path):
        return zip_path

    path = os.path.dirname(os.path.realpath(zip_path))
    search_path = os.path.join(path, "skeletons/*")
    files = glob(search_path)
    if len(files) == 0:
        skel_file = download_wk_skeleton()
    else:
        skel_file = max(files, key=os.path.getctime)

    return os.path.abspath(skel_file)

def rasterize_skeleton(zip_path: str, raw_file: str, raw_ds: str) -> np.ndarray:

logger.info(f"Rasterizing skeleton...")

skel_coor = parse_skeleton(zip_path)

# Initialize rasterized skeleton image
raw = open_ds(raw_file, raw_ds)

    dataset_shape = raw.data.shape
    logger.debug(f"dataset shape: {dataset_shape}")
voxel_size = raw.voxel_size
offset = raw.roi.begin # unhardcode for nonzero offset
image = np.zeros(dataset_shape, dtype=np.uint8)

    def adjust(coor):
        # Shift skeleton coordinates by the dataset offset and clamp them to
        # the last valid index along each axis.
        ds_under = [x - 1 for x in dataset_shape]
        return np.min([coor - offset, ds_under], 0)

    logger.info("Adjusting and drawing skeleton edges...")
    for tree_id, edges in skel_coor.items():
        # Iterate through every edge and write the tree id into the image.
        for start, end in edges:
            line = line_nd(adjust(start), adjust(end))
            image[line] = tree_id


# Save GT rasterization #TODO: implement daisy blockwise option
total_roi = Roi(
Coordinate(offset) * Coordinate(voxel_size),
Coordinate(dataset_shape) * Coordinate(voxel_size),
)

print("saving . . .")
out_ds = prepare_ds(
raw_file,
"volumes/training_rasters",
total_roi,
voxel_size,
image.dtype,
delete=True,
)
out_ds[out_ds.roi] = image

return image
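
A hypothetical end-to-end sketch for the function above (paths and dataset names are assumptions for illustration):

from funlib.persistence import open_ds

# Rasterize a skeleton annotation onto the raw volume's grid; the result is
# also written to the "volumes/training_rasters" dataset inside raw_file.
image = rasterize_skeleton(
    zip_path="annotations/my_annotation.zip",
    raw_file="data/sample.zarr",
    raw_ds="volumes/raw",
)

rasters = open_ds("data/sample.zarr", "volumes/training_rasters")
print(rasters.roi, image.max())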


def get_wk_mask(
annotation_ID,
save_path, # TODO: Add mkdtemp() as default
Expand Down
