Skip to content

Commit

Permalink
Update webknossos_annotation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jingjingwu1225 authored Nov 6, 2024
1 parent 3dcfc6c commit ac6ba97
Showing 1 changed file with 26 additions and 136 deletions.
162 changes: 26 additions & 136 deletions scripts/webknossos_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,65 +23,56 @@
from scipy.ndimage import binary_fill_holes

"""
This script resave downloaded annotations from webknossos in ome.zarr format following direction czyx
which is the same with underlying dataset.
This script resaves annotations from webknossos in ome.zarr format following the czyx direction,
which is the same as the underlying dataset.
It calculates the offset from low-res images and sets the offset for other resolutions accordingly.
If the annotation is contour rather than mask, set is_contour as True else False.
wkw_dir is the path to unzipped annotation
jp2_dir is the path to a jpeg2000 image of same subject to get voxel size information
ome_dir is the path to underlying ome.zarr dataset
dst is the path to saving annotation mask
Input:
wkw_dir is the path to unzipped manual annotation folder, for example: .../annotation_folder/data_Volume
ome_dir is the path to underlying ome.zarr dataset
dst is the path for saving the annotation mask
dic is the dictionary mapping annotation values to standard values
"""

app = cyclopts.App(help_format="markdown")
@app.default
def convert(
wkw_dir: str = None,
jp2_dir: str = None,
ome_dir: str = None,
dst: str = None,
is_contour: bool = True,
dic: dict = None,
*,
chunk: int = 1024,
compressor: str = 'blosc',
compressor_opt: str = "{}",
max_load: int = 16384,
nii: bool = False,
has_channel: int = 1,
):
# load underlying dataset info to get size info
omz_data = zarr.open_group(ome_dir, mode='r')
wkw_dataset_path = os.path.join(wkw_dir, get_mask_name(8))
nblevel = len([i for i in os.listdir(ome_dir) if i.isdigit()])
wkw_dataset_path = os.path.join(wkw_dir, get_mask_name(nblevel-1))
wkw_dataset = wkw.Dataset.open(wkw_dataset_path)

low_res_offsets = []
omz_res = omz_data[8]
size = np.shape(omz_res)
size = [i for i in omz_res.shape[-2:]] + [3]
for idx in range(20):
omz_res = omz_data[nblevel-1]
n = omz_res.shape[1]
size = omz_res.shape[-2:]
for idx in range(n):
offset_x, offset_y = 0, 0
data = wkw_dataset.read(off = (offset_y, offset_x, idx), shape = [size[1], size[0], 1])
data = data[0, :, :, 0]
data = np.transpose(data, (1, 0))
[t,b,l,r] = find_borders(data)
low_res_offsets.append([t,b,l,r])

# load jp2 image to get voxel size info
j2k = glymur.Jp2k(jp2_dir)
vxw, vxh = get_pixelsize(j2k)


# setup save info
basename = os.path.basename(ome_dir)[:-9]
initials = wkw_dir.split('/')[-2][:2]
out = os.path.join(dst, basename + '_dsec_' + initials + '.ome.zarr')
print(out)
if os.path.exists(out):
shutil.rmtree(out)
os.makedirs(out)
os.makedirs(out, exist_ok=True)

if isinstance(compressor_opt, str):
compressor_opt = ast.literal_eval(compressor_opt)
Expand All @@ -91,10 +82,6 @@ def convert(
omz = zarr.group(store=store, overwrite=True)


dic_EB = {0:0, 1:1, 9:2, 2:3, 4:4, 10:5, 11:6, 8:7, 3:8}
dic_JS = {0:0, 4:2, 5:3, 9:4, 6:5, 2:6, 7:7, 3:8}
dic_JW = {0:0, 1:1, 2:2, 3:3, 4:4, 6:5, 7:6, 8:7, 9:8}

# Prepare chunking options
opt = {
'chunks': [1, 1] + [chunk, chunk],
Expand All @@ -107,12 +94,12 @@ def convert(
print(opt)


nblevel = 9

# Write each level
for level in range(nblevel):
omz_res = omz_data[level]
size = omz_res.shape[-2:]
shape = [1, 20] + [i for i in size]
shape = [1, n] + [i for i in size]

wkw_dataset_path = os.path.join(wkw_dir, get_mask_name(level))
wkw_dataset = wkw.Dataset.open(wkw_dataset_path)
Expand All @@ -121,32 +108,23 @@ def convert(
array = omz[f'{level}']

# Write each slice
for idx in range(20):
for idx in range(n):
if -1 in low_res_offsets[idx]:
continue

t, b, l, r = [k*2**(8-level) for k in low_res_offsets[idx]]
t, b, l, r = [k*2**(nblevel-level-1) for k in low_res_offsets[idx]]
height, width = size[0]-t-b, size[1]-l-r

data = wkw_dataset.read(off = (l, t, idx), shape = [width, height, 1])
data = data[0, :, :, 0]
data = np.transpose(data, (1, 0))
if is_contour and level > 3:
if initials == 'EB':
mapped_img = np.array([[dic_EB[data[i][j]] for j in range(data.shape[1])] for i in range(data.shape[0])])
elif initials == 'JW':
mapped_img = np.array([[dic_JW[data[i][j]] for j in range(data.shape[1])] for i in range(data.shape[0])])
elif initials == 'JS':
mapped_img = np.array([[dic_JS[data[i][j]] for j in range(data.shape[1])] for i in range(data.shape[0])])
subdat = generate_mask(mapped_img)

else:
subdat = data
subdat_size = subdat.shape
if dic:
data = np.array([[dic[data[i][j]] for j in range(data.shape[1])] for i in range(data.shape[0])])
subdat_size = data.shape

print('Convert level', level, 'with shape', shape, 'and slice', idx, 'with size', subdat_size)
if max_load is None or (subdat_size[-2] < max_load and subdat_size[-1] < max_load):
array[0, idx, t: t+subdat_size[-2], l: l+subdat_size[-1]] = subdat[...]
array[0, idx, t: t+subdat_size[-2], l: l+subdat_size[-1]] = data[...]
else:
ni = ceildiv(subdat_size[-2], max_load)
nj = ceildiv(subdat_size[-1], max_load)
Expand All @@ -156,59 +134,13 @@ def convert(
print(f'\r{i+1}/{ni}, {j+1}/{nj}', end=' ')
start_x, end_x = i*max_load, min((i+1)*max_load, subdat_size[-2])
start_y, end_y = j*max_load, min((j+1)*max_load, subdat_size[-1])
array[0, idx, t + start_x: t + end_x, l + start_y: l + end_y] = subdat[start_x: end_x, start_y: end_y]
array[0, idx, t + start_x: t + end_x, l + start_y: l + end_y] = data[start_x: end_x, start_y: end_y]
print('')


# Write OME-Zarr multiscale metadata
print('Write metadata')
multiscales = [{
'version': '0.4',
'axes': [
{"name": "z", "type": "space", "unit": "micrometer"},
{"name": "y", "type": "space", "unit": "micrometer"},
{"name": "x", "type": "space", "unit": "micrometer"}
],
'datasets': [],
'type': 'jpeg2000',
'name': '',
}]

if has_channel:
multiscales[0]['axes'].insert(0, {"name": "c", "type": "channel"})
for n in range(nblevel):
shape0 = omz['0'].shape[-2:]
shape = omz[str(n)].shape[-2:]
multiscales[0]['datasets'].append({})
level = multiscales[0]['datasets'][-1]
level["path"] = str(n)


level["coordinateTransformations"] = [
{
"type": "scale",
"scale": [1.0] * has_channel + [
1.0,
(shape0[0]/shape[0])*vxh,
(shape0[1]/shape[1])*vxw,
]
},
{
"type": "translation",
"translation": [0.0] * has_channel + [
0.0,
(shape0[0]/shape[0] - 1)*vxh*0.5,
(shape0[1]/shape[1] - 1)*vxw*0.5,
]
}
]
multiscales[0]["coordinateTransformations"] = [
{
"scale": [1.0] * (3 + has_channel),
"type": "scale"
}
]
omz.attrs["multiscales"] = multiscales
omz.attrs["multiscales"] = omz_data.attrs["multiscales"]



Expand Down Expand Up @@ -240,28 +172,7 @@ def find_borders(img):
l = cal_distance(np.rot90(img, k=3))
r = cal_distance(np.rot90(img, k=1))

return [t, b, l, r]


def contour_to_mask(mask, value):
    """Fill a single-label contour image into a solid labeled mask.

    Parameters
    ----------
    mask : ndarray
        2-D binary image where the drawn contour pixels are non-zero.
    value : int
        Label value to write inside the filled region.

    Returns
    -------
    ndarray
        Array of the same shape with ``value`` inside the filled contour
        and 0 elsewhere.
    """
    # Morphologically close small breaks in the hand-drawn contour so the
    # hole fill below does not leak through gaps.
    kernel = np.ones((5, 5), np.uint8)
    closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    # Fill the interior of the (now closed) contour; result is boolean.
    single_mask = binary_fill_holes(closed_mask)
    # Replace the boolean fill with the requested label value.
    return np.where(single_mask, value, 0)


def generate_mask(mask):
    """Convert a multi-label contour image into a filled multi-label mask.

    For each label value 1-9 that appears in ``mask``, the contour is
    closed and filled via :func:`contour_to_mask`, and the per-label
    results are merged by keeping the elementwise maximum.
    """
    merged = np.zeros_like(mask).astype(np.uint8)
    for label in range(1, 10):
        # Skip labels that never occur in this slice.
        if label not in mask:
            continue
        outline = np.where(mask == label, 255, 0).astype(np.uint8)
        filled = contour_to_mask(outline, label)
        merged = np.where(merged < filled, filled, merged)
    return merged
return [max(0, k-1) for k in [t, b, l, r]]


def make_compressor(name, **prm):
Expand All @@ -277,27 +188,6 @@ def make_compressor(name, **prm):
return Compressor(**prm)


def get_pixelsize(j2k):
    """Return (pixel width, pixel height) read from a JPEG 2000 file's
    Adobe XMP metadata box, falling back to 1.0 for any value that is
    absent or malformed.

    https://en.wikipedia.org/wiki/Extensible_Metadata_Platform
    """
    XMP_UUID = 'BE7ACFCB97A942E89C71999491E3AFAC'
    TAG_Images = '{http://ns.adobe.com/xap/1.0/}Images'
    Tag_Desc = '{http://www.w3.org/1999/02/22-rdf-syntax-ns#}Description'
    Tag_PixelWidth = '{http://ns.adobe.com/xap/1.0/}PixelWidth'
    Tag_PixelHeight = '{http://ns.adobe.com/xap/1.0/}PixelHeight'

    width = height = 1.0
    xmp_id = uuid.UUID(XMP_UUID)
    for box in j2k.box:
        # Only UUID boxes carrying the Adobe XMP identifier are of interest.
        if getattr(box, 'uuid', None) != xmp_id:
            continue
        try:
            images = next(iter(box.data.iter(TAG_Images)))
            desc = next(iter(images.iter(Tag_Desc)))
            width = float(desc.attrib[Tag_PixelWidth])
            height = float(desc.attrib[Tag_PixelHeight])
        except Exception:
            # Malformed or incomplete XMP — keep the 1.0 defaults.
            pass
    return width, height


if __name__ == "__main__":
app()

0 comments on commit ac6ba97

Please sign in to comment.