Skip to content

Commit

Permalink
Merge pull request #22 from jingjingwu1225/webknossos-annotate
Browse files Browse the repository at this point in the history
Refactor of webknossos annotation converter
  • Loading branch information
calvinchai authored Nov 25, 2024
2 parents 8d22e32 + d9377f6 commit f91aeaf
Show file tree
Hide file tree
Showing 10 changed files with 405 additions and 8 deletions.
1 change: 1 addition & 0 deletions conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- zarr
- nibabel
- tifffile
- wkw
- tensorstore
- pytest
- ruff
Expand Down
5 changes: 3 additions & 2 deletions linc_convert/modalities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Converters for all imaging modalities."""

__all__ = ["df", "lsm", "psoct"]
from . import df, lsm, psoct
__all__ = ["df", "lsm", "wk", "psoct"]
from . import df, lsm, wk, psoct

4 changes: 4 additions & 0 deletions linc_convert/modalities/wk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Webknossos annotation converters."""

__all__ = ["cli", "webknossos_annotation"]
from . import cli, webknossos_annotation
9 changes: 9 additions & 0 deletions linc_convert/modalities/wk/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Entry-points for Webknossos annotation converter."""

from cyclopts import App

from linc_convert.cli import main

help = "Converters for Webknossos annotation"
wk = App(name="wk", help=help)
main.command(wk)
258 changes: 258 additions & 0 deletions linc_convert/modalities/wk/webknossos_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
"""Convert annotation downloaded from webknossos into ome.zarr format."""

# stdlib
import ast
import json
import os
import shutil

import cyclopts
import numpy as np

# externals
import wkw
import zarr

# internals
from linc_convert.modalities.wk.cli import wk
from linc_convert.utils.math import ceildiv
from linc_convert.utils.zarr import make_compressor

webknossos = cyclopts.App(name="webknossos", help_format="markdown")
wk.command(webknossos)


@webknossos.default
def convert(
wkw_dir: str = None,
ome_dir: str = None,
out: str = None,
dic: str = None,
*,
chunk: int = 1024,
compressor: str = "blosc",
compressor_opt: str = "{}",
max_load: int = 16384,
) -> None:
"""
Convert annotations (in .wkw format) from webknossos to ome.zarr format.
This script converts annotations from webknossos, following the czyx direction,
to the ome.zarr format.
The conversion ensures that the annotations match the underlying dataset.
Parameters
----------
wkw_dir : str
Path to the unzipped manual annotation folder downloaded from webknossos
in .wkw format. For example: .../annotation_folder/data_Volume.
ome_dir : str
Path to the underlying ome.zarr dataset, following the BIDS naming standard.
out : str
Path to the output directory for saving the converted ome.zarr.
The ome.zarr file name is generated automatically based on ome_dir
and the initials of the annotator.
dic : dict
A dictionary mapping annotation values to the following standard values
if annotation doesn't match the standard.
The dictionary should be in single quotes, with keys in double quotes,
for example: dic = '{"2": 1, "4": 2}'.
The standard values are:
- 0: background
- 1: Light Bundle
- 2: Moderate Bundle
- 3: Dense Bundle
- 4: Light Terminal
- 5: Moderate Terminal
- 6: Dense Terminal
- 7: Single Fiber
"""
dic = json.loads(dic)
dic = {int(key): int(value) for key, value in dic.items()}

# load underlying dataset info to get size info
omz_data = zarr.open_group(ome_dir, mode="r")
nblevel = len([i for i in os.listdir(ome_dir) if i.isdigit()])
wkw_dataset_path = os.path.join(wkw_dir, get_mask_name(nblevel - 1))
wkw_dataset = wkw.Dataset.open(wkw_dataset_path)

low_res_offsets = []
omz_res = omz_data[nblevel - 1]
n = omz_res.shape[1]
size = omz_res.shape[-2:]
for idx in range(n):
offset_x, offset_y = 0, 0
data = wkw_dataset.read(
off=(offset_y, offset_x, idx), shape=[size[1], size[0], 1]
)
data = data[0, :, :, 0]
data = np.transpose(data, (1, 0))
[t0, b0, l0, r0] = find_borders(data)
low_res_offsets.append([t0, b0, l0, r0])

# setup save info
basename = os.path.basename(ome_dir)[:-9]
initials = wkw_dir.split("/")[-2][:2]
out = os.path.join(out, basename + "_dsec_" + initials + ".ome.zarr")
if os.path.exists(out):
shutil.rmtree(out)
os.makedirs(out, exist_ok=True)

if isinstance(compressor_opt, str):
compressor_opt = ast.literal_eval(compressor_opt)

# Prepare Zarr group
store = zarr.storage.DirectoryStore(out)
omz = zarr.group(store=store, overwrite=True)

# Prepare chunking options
opt = {
"chunks": [1, 1] + [chunk, chunk],
"dimension_separator": r"/",
"order": "F",
"dtype": "uint8",
"fill_value": None,
"compressor": make_compressor(compressor, **compressor_opt),
}
print(opt)

# Write each level
for level in range(nblevel):
omz_res = omz_data[level]
size = omz_res.shape[-2:]
shape = [1, n] + [i for i in size]

wkw_dataset_path = os.path.join(wkw_dir, get_mask_name(level))
wkw_dataset = wkw.Dataset.open(wkw_dataset_path)

omz.create_dataset(f"{level}", shape=shape, **opt)
array = omz[f"{level}"]

# Write each slice
for idx in range(n):
if -1 in low_res_offsets[idx]:
array[0, idx, :1, :1] = np.zeros((1, 1), dtype=np.uint8)
continue

top, bottom, left, right = [
k * 2 ** (nblevel - level - 1) for k in low_res_offsets[idx]
]
height, width = size[0] - top - bottom, size[1] - left - right

data = wkw_dataset.read(off=(left, top, idx), shape=[width, height, 1])
data = data[0, :, :, 0]
data = np.transpose(data, (1, 0))
if dic:
data = np.array(
[
[dic[data[i][j]] for j in range(data.shape[1])]
for i in range(data.shape[0])
]
)
subdat_size = data.shape

print(
"Convert level",
level,
"with shape",
shape,
"and slice",
idx,
"with size",
subdat_size,
)
if max_load is None or (
subdat_size[-2] < max_load and subdat_size[-1] < max_load
):
array[
0, idx, top : top + subdat_size[-2], left : left + subdat_size[-1]
] = data[...]
else:
ni = ceildiv(subdat_size[-2], max_load)
nj = ceildiv(subdat_size[-1], max_load)

for i in range(ni):
for j in range(nj):
print(f"\r{i+1}/{ni}, {j+1}/{nj}", end=" ")
start_x, end_x = (i * max_load,)
min((i + 1) * max_load, subdat_size[-2])

start_y, end_y = (j * max_load,)
min((j + 1) * max_load, subdat_size[-1])
array[
0,
idx,
top + start_x : top + end_x,
left + start_y : left + end_y,
] = data[start_x:end_x, start_y:end_y]
print("")

# Write OME-Zarr multiscale metadata
print("Write metadata")
omz.attrs["multiscales"] = omz_data.attrs["multiscales"]


def get_mask_name(level: int) -> str:
"""
Return the name of the mask for a given resolution level.
Parameters
----------
level : int
The resolution level for which to return the mask name.
Returns
-------
str
The name of the mask for the given level.
"""
if level == 0:
return "1"
else:
return f"{2**level}-{2**level}-1"


def cal_distance(img: np.ndarray) -> int:
"""
Return the distance of non-zero values to the top border.
Parameters
----------
img : np.ndarray
The array to calculate distance of object inside to border
Returns
-------
int
The distance of non-zero to the top border
"""
m = img.shape[0]
for i in range(m):
cnt = np.sum(img[i, :])
if cnt > 0:
return i
return m


def find_borders(img: np.ndarray) -> np.ndarray:
"""
Return the distances of non-zero values to four borders.
Parameters
----------
img : np.ndarray
The array to calculate distance of object inside to border
Returns
-------
int
The distance of non-zero values to four borders
"""
if np.max(img) == 0:
return [-1, -1, -1, -1]
top = cal_distance(img)
bottom = cal_distance(img[::-1])
left = cal_distance(np.rot90(img, k=3))
right = cal_distance(np.rot90(img, k=1))

return [max(0, k - 1) for k in [top, bottom, left, right]]
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ numpy = "*"
nibabel = "*"
zarr = "^2.0.0"


[tool.poetry.group.df]
optional = true
[tool.poetry.group.df.dependencies]
Expand All @@ -45,6 +46,11 @@ optional = true
h5py = "*"
scipy = "*"

[tool.poetry.group.wk]
optional = true
[tool.poetry.group.wk.dependencies]
wkw = "*"

[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
Expand Down
19 changes: 19 additions & 0 deletions tests/data/generate_trusted_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import tempfile
import zipfile
from pathlib import Path

import test_df
import test_lsm
import test_wk
import zarr

from linc_convert.modalities.df import multi_slice
from linc_convert.modalities.lsm import mosaic
from linc_convert.modalities.wk import webknossos_annotation

if __name__ == "__main__":
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -27,3 +30,19 @@
output_zarr = os.path.join(tmp_dir, "output.zarr")
mosaic.convert(tmp_dir, output_zarr)
zarr.copy_all(zarr.open(output_zarr), zarr.open("data/lsm.zarr.zip", "w"))

with tempfile.TemporaryDirectory() as tmp_dir:
test_wk._write_test_data(tmp_dir)

tmp_dir = Path(tmp_dir)
wkw_dir = str(tmp_dir / "wkw")
ome_dir = str(tmp_dir / "ome")

basename = os.path.basename(ome_dir)[:-9]
initials = wkw_dir.split("/")[-2][:2]
output_zarr = os.path.join(
tmp_dir, basename + "_dsec_" + initials + ".ome.zarr"
)

webknossos_annotation.convert(wkw_dir, ome_dir, tmp_dir, "{}")
zarr.copy_all(zarr.open(output_zarr), zarr.open("data/wk.zarr.zip", "w"))
Binary file added tests/data/wk.zarr.zip
Binary file not shown.
7 changes: 1 addition & 6 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ def _cmp_zarr_archives(path1: str, path2: str) -> bool:

# Compare keys (dataset structure)
if zarr1.keys() != zarr2.keys():
print(list(zarr1.keys()))
print(list(zarr2.keys()))
print("keys mismatch")
return False
if zarr1.attrs != zarr2.attrs:
Expand All @@ -34,10 +32,7 @@ def _cmp_zarr_archives(path1: str, path2: str) -> bool:
array1 = zarr1[key][:]
array2 = zarr2[key][:]

# Check for equality of the arrays
if not np.array_equal(array1, array2):
print(f"Mismatch found in dataset: {key}")
return False
np.testing.assert_allclose(array1, array2)
if zarr1[key].attrs != zarr2[key].attrs:
print("attrs mismatch")
return False
Expand Down
Loading

0 comments on commit f91aeaf

Please sign in to comment.