Add a loop so the OME-Zarr saves multiple slices, following the directory layout resolution_level/RGB(0)/slice_number/chunk_folder/chunk_files
Jingjing Wu committed Jul 17, 2024
1 parent 92ec42e commit e44038e
Showing 1 changed file with 172 additions and 133 deletions.
305 changes: 172 additions & 133 deletions scripts/jp2_to_zarr.py
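The directory layout named in the commit title falls directly out of the new chunking scheme: with dimension_separator='/' every chunk index becomes a directory level, and with chunks of (3, 1, chunk, chunk) there is a single chunk spanning the RGB axis and exactly one slice per chunk along z. A minimal sketch of the effect (not part of the commit; the path and sizes are made up):

    import zarr

    store = zarr.storage.DirectoryStore('demo.ome.zarr')
    grp = zarr.group(store=store, overwrite=True)
    arr = grp.create_dataset('0', shape=(3, 4, 8192, 8192),
                             chunks=(3, 1, 4096, 4096),
                             dimension_separator='/', dtype='|u1')
    arr[:, 2, :4096, :4096] = 1
    # -> chunk file demo.ome.zarr/0/0/2/0/0
    #    i.e. resolution_level/RGB(0)/slice_number/chunk_folder/chunk_file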
@@ -18,22 +18,30 @@
 import zarr
 import ast
 import numcodecs
+import json
+import uuid
 import os
+import math
 import numpy as np
 from glob import glob
 import nibabel as nib
 from typing import Optional
 
+HOME = '/space/aspasia/2/users/linc/000003'
+
+# Path to LincBrain dataset
+LINCSET = os.path.join(HOME, 'sourcedata')
+LINCOUT = os.path.join(HOME, 'rawdata')
 app = cyclopts.App(help_format="markdown")
 
 
 @app.default
 def convert(
-    inp: str,
+    inp: str = None,
     out: str = None,
+    subjects: list = [],
     *,
-    chunk: int = 1024,
+    chunk: int = 4096,
     compressor: str = 'blosc',
     compressor_opt: str = "{}",
     max_load: int = 16384,
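Hard-coded dataset roots replace the old single-input interface: inputs are now discovered per subject under LINCSET and written under LINCOUT, so a run is driven by a subjects list rather than a positional path. An illustrative invocation (subject IDs are made up, and the exact cyclopts list syntax may vary):

    python scripts/jp2_to_zarr.py --subjects MR243 --subjects MR244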
@@ -89,142 +97,172 @@ def convert(
     thickness
         Slice thickness
     """
-    if not out:
-        out = os.path.splitext(inp)[0]
-        out += '.nii.zarr' if nii else '.ome.zarr'
+    for LINCSUB in subjects:
+        print('working on subject', LINCSUB)
+        HISTO_FOLDER = os.path.join(LINCSET, f'sub-{LINCSUB}/micr')
+        OUT_FOLDER = os.path.join(LINCOUT, f'sub-{LINCSUB}/micr')
+        os.makedirs(OUT_FOLDER, exist_ok=True)
+        inp_dir = list(sorted(glob(os.path.join(HISTO_FOLDER, f'*nDF.jp2'))))
+
+        start_num, end_num = 0, len(inp_dir)-1
+        out = os.path.join(OUT_FOLDER, f'sub-{LINCSUB}_sample-slice{start_num:04d}slice{end_num:04d}_stain-LY_DF')
+        out += '.nii.zarr' if nii else '.ome.zarr'

-    nii = nii or out.endswith('.nii.zarr')
-
-    if isinstance(compressor_opt, str):
-        compressor_opt = ast.literal_eval(compressor_opt)
-
-    j2k = glymur.Jp2k(inp)
-    vxw, vxh = get_pixelsize(j2k)
-
-    # Prepare Zarr group
-    omz = zarr.storage.DirectoryStore(out)
-    omz = zarr.group(store=omz, overwrite=True)
-
-    # Prepare chunking options
-    opt = {
-        'chunks': list(j2k.shape[2:]) + [chunk, chunk],
-        'dimension_separator': r'/',
-        'order': 'F',
-        'dtype': np.dtype(j2k.dtype).str,
-        'fill_value': None,
-        'compressor': make_compressor(compressor, **compressor_opt),
-    }
-
-    # Write each level
-    nblevel = j2k.codestream.segment[2].num_res
-    has_channel = j2k.ndim - 2
-    for level in range(nblevel):
-        subdat = WrappedJ2K(j2k, level=level)
-        shape = subdat.shape
-        print('Convert level', level, 'with shape', shape)
-        omz.create_dataset(str(level), shape=shape, **opt)
-        array = omz[str(level)]
-        if max_load is None or (shape[-2] < max_load and shape[-1] < max_load):
-            array[...] = subdat[...]
-        else:
-            ni = ceildiv(shape[-2], max_load)
-            nj = ceildiv(shape[-1], max_load)
-            for i in range(ni):
-                for j in range(nj):
-                    print(f'\r{i+1}/{ni}, {j+1}/{nj}', end='')
-                    array[
-                        ...,
-                        i*max_load:min((i+1)*max_load, shape[-2]),
-                        j*max_load:min((j+1)*max_load, shape[-1]),
-                    ] = subdat[
-                        ...,
-                        i*max_load:min((i+1)*max_load, shape[-2]),
-                        j*max_load:min((j+1)*max_load, shape[-1]),
-                    ]
-            print('')
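The removed block above is the old single-file path: one Jp2k opened up front and one 2-D (plus channel) array per level. The replacement below collects every *nDF.jp2 in the subject's micr/ folder, sizes a common canvas from the largest slice, and writes all slices into one shared 4-D array per level, producing a store named like sub-<ID>_sample-slice0000slice0299_stain-LY_DF.ome.zarr (slice range illustrative).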
+        nii = nii or out.endswith('.nii.zarr')
+        print(out)
+
+        if isinstance(compressor_opt, str):
+            compressor_opt = ast.literal_eval(compressor_opt)
+
+        # Prepare Zarr group
+        omz = zarr.storage.DirectoryStore(out)
+        omz = zarr.group(store=omz, overwrite=True)
+
+        nblevel, has_channel, dtype_jp2 = float('inf'), float('inf'), ''
+        # get new_size: common canvas = element-wise max over all slices
+        new_height, new_width = 0, 0
+        for inp in inp_dir:
+            jp2 = glymur.Jp2k(inp)
+            nblevel = min(nblevel, jp2.codestream.segment[2].num_res)
+            has_channel = min(has_channel, jp2.ndim - 2)
+            dtype_jp2 = np.dtype(jp2.dtype).str
+            if jp2.shape[0] > new_height:
+                new_height = jp2.shape[0]
+            if jp2.shape[1] > new_width:
+                new_width = jp2.shape[1]
+        new_size = (new_height, new_width, 3) if has_channel else (new_height, new_width)
+        print(len(inp_dir), new_size, nblevel, has_channel)
+
+        # Prepare chunking options
+        opt = {
+            'chunks': list(new_size[2:]) + [1] + [chunk, chunk],
+            'dimension_separator': r'/',
+            'order': 'F',
+            'dtype': dtype_jp2,
+            'fill_value': None,
+            'compressor': make_compressor(compressor, **compressor_opt),
+        }
+        print(opt)
+
+        # Write each level
+        for level in range(nblevel):
+            shape = [ceildiv(s, 2**level) for s in new_size[:2]]
+            shape = [new_size[2]] + [len(inp_dir)] + [s for s in shape]
+
+            omz.create_dataset(f'{level}', shape=shape, **opt)
+            array = omz[f'{level}']
+
+            # Write each slice
+            for idx, inp in enumerate(inp_dir):
+                j2k = glymur.Jp2k(inp)
+                vxw, vxh = get_pixelsize(j2k)
+                subdat = WrappedJ2K(j2k, level=level)
+                subdat_size = subdat.shape
+                print('Convert level', level, 'with shape', shape,
+                      'for slice', idx, 'with size', subdat_size)
+
+                # offset that centers this slice inside the common canvas
+                x, y = (int((shape[-2] - subdat_size[-2])/2), int((shape[-1] - subdat_size[-1])/2))
+
+                if max_load is None or (shape[-2] < max_load and shape[-1] < max_load):
+                    array[..., idx, :, :] = np.zeros((3, shape[-2], shape[-1]), dtype=np.uint8)
+                    array[..., idx, x: x + subdat_size[1], y: y + subdat_size[2]] = subdat[...]
+                else:
+                    ni = ceildiv(shape[-2], max_load)
+                    nj = ceildiv(shape[-1], max_load)
+
+                    for i in range(ni):
+                        for j in range(nj):
+                            print(f'\r{i+1}/{ni}, {j+1}/{nj}', end=' ')
+                            start_x, end_x = i*max_load, min((i+1)*max_load, shape[-2])
+                            start_y, end_y = j*max_load, min((j+1)*max_load, shape[-1])
+                            array[..., idx, start_x:end_x, start_y:end_y] = np.zeros((3, end_x-start_x, end_y-start_y), dtype=np.uint8)
+                            if end_x <= x or end_y <= y:
+                                continue
+
+                            if start_x >= subdat_size[-2] or start_y >= subdat_size[-1]:
+                                continue
+
+                            array[
+                                ...,
+                                idx,
+                                x + start_x: x + min(end_x, subdat_size[-2]),
+                                y + start_y: y + min(end_y, subdat_size[-1]),
+                            ] = subdat[
+                                ...,
+                                start_x: min((i+1)*max_load, subdat_size[-2]),
+                                start_y: min((j+1)*max_load, subdat_size[-1]),
+                            ]
+            print('')
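Since each slice is pasted centered, a slice smaller than the canvas starts at row offset x = (canvas_height - slice_height) // 2, and the tile loop reuses the tile indices on the slice side while shifting the destination by (x, y). A toy check of the offset arithmetic (numbers made up):

    canvas_h, slice_h = 10, 6
    x = int((canvas_h - slice_h) / 2)    # 2
    # slice rows land on canvas rows 2 .. 7 (x .. x + slice_h - 1)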

-
-    # Write OME-Zarr multiscale metadata
-    print('Write metadata')
-    multiscales = [{
-        'version': '0.4',
-        'axes': [
-            {"name": "y", "type": "space", "unit": "micrometer"},
-            {"name": "x", "type": "space", "unit": "micrometer"}
-        ],
-        'datasets': [],
-        'type': 'jpeg2000',
-        'name': '',
-    }]
-    if has_channel:
-        multiscales[0]['axes'].insert(0, {"name": "c", "type": "channel"})
-
-    for n in range(nblevel):
-        shape0 = omz['0'].shape[-2:]
-        shape = omz[str(n)].shape[-2:]
-        multiscales[0]['datasets'].append({})
-        level = multiscales[0]['datasets'][-1]
-        level["path"] = str(n)
-
-        # I assume that wavelet transforms end up aligning voxel edges
-        # across levels, so the effective scaling is the shape ratio,
-        # and there is a half voxel shift wrt to the "center of first voxel"
-        # frame
-        level["coordinateTransformations"] = [
-            {
-                "type": "scale",
-                "scale": [1.0] * has_channel + [
-                    (shape0[0]/shape[0])*vxh,
-                    (shape0[1]/shape[1])*vxw,
-                ]
-            },
-            {
-                "type": "translation",
-                "translation": [0.0] * has_channel + [
-                    (shape0[0]/shape[0] - 1)*vxh*0.5,
-                    (shape0[1]/shape[1] - 1)*vxw*0.5,
-                ]
-            }
-        ]
-    multiscales[0]["coordinateTransformations"] = [
-        {
-            "scale": [1.0] * (2 + has_channel),
-            "type": "scale"
-        }
-    ]
-    omz.attrs["multiscales"] = multiscales
-
-    if not nii:
-        print('done.')
-        return
-
-    # Write NIfTI-Zarr header
-    # NOTE: we use nifti2 because dimensions typically do not fit in a short
-    # TODO: we do not write the json zattrs, but it should be added in
-    #   once the nifti-zarr package is released
-    shape = list(reversed(omz['0'].shape))
-    if has_channel:
-        shape = shape[:2] + [1, 1] + shape[2:]
-    affine = orientation_to_affine(orientation, vxw, vxh, thickness or 1)
-    if center:
-        affine = center_affine(affine, shape[:2])
-    header = nib.Nifti2Header()
-    header.set_data_shape(shape)
-    header.set_data_dtype(omz['0'].dtype)
-    header.set_qform(affine)
-    header.set_sform(affine)
-    header.set_xyzt_units(nib.nifti1.unit_codes.code['micron'])
-    header.structarr['magic'] = b'nz2\0'
-    header = np.frombuffer(header.structarr.tobytes(), dtype='u1')
-    opt = {
-        'chunks': [len(header)],
-        'dimension_separator': r'/',
-        'order': 'F',
-        'dtype': '|u1',
-        'fill_value': None,
-        'compressor': None,
-    }
-    omz.create_dataset('nifti', data=header, shape=shape, **opt)
-    print('done.')
+
+        # Write OME-Zarr multiscale metadata
+        print('Write metadata')
+        multiscales = [{
+            'version': '0.4',
+            'axes': [
+                {"name": "z", "type": "space", "unit": "micrometer"},
+                {"name": "y", "type": "space", "unit": "micrometer"},
+                {"name": "x", "type": "space", "unit": "micrometer"}
+            ],
+            'datasets': [],
+            'type': 'jpeg2000',
+            'name': '',
+        }]
+        if has_channel:
+            multiscales[0]['axes'].insert(0, {"name": "c", "type": "channel"})
+
+        for n in range(nblevel):
+            shape0 = omz['0'].shape[-2:]
+            shape = omz[str(n)].shape[-2:]
+            multiscales[0]['datasets'].append({})
+            level = multiscales[0]['datasets'][-1]
+            level["path"] = str(n)
+
+            # I assume that wavelet transforms end up aligning voxel edges
+            # across levels, so the effective scaling is the shape ratio,
+            # and there is a half voxel shift wrt to the "center of first voxel"
+            # frame
+            level["coordinateTransformations"] = [
+                {
+                    "type": "scale",
+                    "scale": [1.0] * has_channel + [
+                        1.0,
+                        (shape0[0]/shape[0])*vxh,
+                        (shape0[1]/shape[1])*vxw,
+                    ]
+                },
+                {
+                    "type": "translation",
+                    "translation": [0.0] * has_channel + [
+                        0.0,
+                        (shape0[0]/shape[0] - 1)*vxh*0.5,
+                        (shape0[1]/shape[1] - 1)*vxw*0.5,
+                    ]
+                }
+            ]
+        multiscales[0]["coordinateTransformations"] = [
+            {
+                "scale": [1.0] * (3 + has_channel),
+                "type": "scale"
+            }
+        ]
+        omz.attrs["multiscales"] = multiscales
+
+        # Write sidecar .json file
+        json_name = os.path.join(OUT_FOLDER, f'sub-{LINCSUB}_sample-slice{start_num:04d}slice{end_num:04d}_stain-LY_DF.json')
+        dic = {}
+        dic['PixelSize'] = json.dumps([vxw, vxh])
+        dic['PixelSizeUnits'] = 'um'
+        dic['SliceThickness'] = 1.2
+        dic['SliceThicknessUnits'] = 'mm'
+        dic['SampleStaining'] = 'LY'
+
+        with open(json_name, "w") as outfile:
+            json.dump(dic, outfile)
+            outfile.write('\n')
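The per-level translation in the new metadata follows from voxel centers: at downsampling factor f = shape0/shape the voxel size becomes f*vx and the center of the first voxel sits (f - 1)/2 * vx from the level-0 origin, which is exactly the (shape0/shape - 1)*vx*0.5 terms above. Note also that PixelSize is passed through json.dumps, so the sidecar stores it as a string rather than a numeric array; the file comes out like this (pixel size illustrative):

    {"PixelSize": "[0.00092, 0.00092]", "PixelSizeUnits": "um", "SliceThickness": 1.2, "SliceThicknessUnits": "mm", "SampleStaining": "LY"}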



 def orientation_ensure_3d(orientation):

@@ -401,3 +439,4 @@

 if __name__ == "__main__":
     app()
+
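A quick way to sanity-check a converted store (store name and printed values illustrative):

    import zarr

    z = zarr.open('sub-MR243_sample-slice0000slice0299_stain-LY_DF.ome.zarr', mode='r')
    print(z['0'].shape)                        # (3, n_slices, height, width)
    print(z.attrs['multiscales'][0]['axes'])   # c, z, y, x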