Commit 1700119

Merge pull request #6 from lincbrain/jp2_to_zarr_mul_slices
jp2_to_zarr_mul_slices
jingjingwu1225 authored Jul 18, 2024
2 parents 92ec42e + e44038e commit 1700119
Showing 1 changed file: scripts/jp2_to_zarr.py (172 additions, 133 deletions)
@@ -18,22 +18,30 @@
import zarr
import ast
import numcodecs
import json
import uuid
import os
import math
import numpy as np
from glob import glob
import nibabel as nib
from typing import Optional

HOME = '/space/aspasia/2/users/linc/000003'

# Path to LincBrain dataset
LINCSET = os.path.join(HOME, 'sourcedata')
LINCOUT = os.path.join(HOME, 'rawdata')
app = cyclopts.App(help_format="markdown")
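# Example invocation (a sketch: the flag names assume cyclopts' default
# CLI mapping of keyword parameters, and the subject IDs below are
# placeholders, not real data):
#
#   python jp2_to_zarr.py --subjects 001 002 --chunk 4096 --compressor blosc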


@app.default
def convert(
    inp: str = None,
    out: str = None,
    subjects: list = [],
    *,
    chunk: int = 4096,
    compressor: str = 'blosc',
    compressor_opt: str = "{}",
    max_load: int = 16384,
@@ -89,142 +97,172 @@ def convert(
    thickness
        Slice thickness
    """
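    # The body below loops over the requested subjects and stitches each
    # subject's '*nDF.jp2' slices into one multiscale (c, z, y, x) volume.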
    for LINCSUB in subjects:
        print('working on subject', LINCSUB)
        HISTO_FOLDER = os.path.join(LINCSET, f'sub-{LINCSUB}/micr')
        OUT_FOLDER = os.path.join(LINCOUT, f'sub-{LINCSUB}/micr')
        os.makedirs(OUT_FOLDER, exist_ok=True)
        inp_dir = list(sorted(glob(os.path.join(HISTO_FOLDER, '*nDF.jp2'))))

        start_num, end_num = 0, len(inp_dir) - 1
        out = os.path.join(OUT_FOLDER, f'sub-{LINCSUB}_sample-slice{start_num:04d}slice{end_num:04d}_stain-LY_DF')
        out += '.nii.zarr' if nii else '.ome.zarr'
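        # The output name encodes the covered slice range, e.g.
        # sub-<ID>_sample-slice0000slice0123_stain-LY_DF.ome.zarr
        # (the slice count depends on how many '*nDF.jp2' files are found).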
        nii = nii or out.endswith('.nii.zarr')
        print(out)

        if isinstance(compressor_opt, str):
            compressor_opt = ast.literal_eval(compressor_opt)

        # Prepare Zarr group
        omz = zarr.storage.DirectoryStore(out)
        omz = zarr.group(store=omz, overwrite=True)

        # Scan all slices to find the common number of levels, channel
        # presence, dtype, and the maximum in-plane size across slices
        nblevel, has_channel, dtype_jp2 = float('inf'), float('inf'), ''
        new_height, new_width = 0, 0
        for inp in inp_dir:
            jp2 = glymur.Jp2k(inp)
            nblevel = min(nblevel, jp2.codestream.segment[2].num_res)
            has_channel = min(has_channel, jp2.ndim - 2)
            dtype_jp2 = np.dtype(jp2.dtype).str
            if jp2.shape[0] > new_height:
                new_height = jp2.shape[0]
            if jp2.shape[1] > new_width:
                new_width = jp2.shape[1]
        new_size = (new_height, new_width, 3) if has_channel else (new_height, new_width)
        print(len(inp_dir), new_size, nblevel, has_channel)

        # Prepare chunking options
        opt = {
            'chunks': list(new_size[2:]) + [1] + [chunk, chunk],
            'dimension_separator': r'/',
            'order': 'F',
            'dtype': dtype_jp2,
            'fill_value': None,
            'compressor': make_compressor(compressor, **compressor_opt),
        }
        print(opt)
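        # With RGB input this gives chunks of (3, 1, chunk, chunk): a z-chunk
        # of 1 means every slice lands in its own set of chunk files, so
        # slices can be written (and later read) independently.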


        # Write each level
        for level in range(nblevel):
            shape = [ceildiv(s, 2**level) for s in new_size[:2]]
            shape = [new_size[2]] + [len(inp_dir)] + [s for s in shape]

            omz.create_dataset(f'{level}', shape=shape, **opt)
            array = omz[f'{level}']

            # Write each slice
            for idx, inp in enumerate(inp_dir):
                j2k = glymur.Jp2k(inp)
                vxw, vxh = get_pixelsize(j2k)
                subdat = WrappedJ2K(j2k, level=level)
                subdat_size = subdat.shape
                print('Convert level', level, 'with shape', shape,
                      'for slice', idx, 'with size', subdat_size)

                # Offset that centres this slice within the padded grid
                x, y = (int((shape[-2] - subdat_size[-2])/2),
                        int((shape[-1] - subdat_size[-1])/2))

                if max_load is None or (shape[-2] < max_load and shape[-1] < max_load):
                    array[..., idx, :, :] = np.zeros((3, shape[-2], shape[-1]), dtype=np.uint8)
                    array[..., idx, x:x + subdat_size[1], y:y + subdat_size[2]] = subdat[...]
                else:
                    ni = ceildiv(shape[-2], max_load)
                    nj = ceildiv(shape[-1], max_load)

                    for i in range(ni):
                        for j in range(nj):
                            print(f'\r{i+1}/{ni}, {j+1}/{nj}', end=' ')
                            start_x, end_x = i*max_load, min((i+1)*max_load, shape[-2])
                            start_y, end_y = j*max_load, min((j+1)*max_load, shape[-1])
                            array[..., idx, start_x:end_x, start_y:end_y] = np.zeros(
                                (3, end_x-start_x, end_y-start_y), dtype=np.uint8)
                            if end_x <= x or end_y <= y:
                                continue
                            if start_x >= subdat_size[-2] or start_y >= subdat_size[-1]:
                                continue
                            array[
                                ...,
                                idx,
                                x + start_x: x + min(end_x, subdat_size[-2]),
                                y + start_y: y + min(end_y, subdat_size[-1]),
                            ] = subdat[
                                ...,
                                start_x: min((i+1)*max_load, subdat_size[-2]),
                                start_y: min((j+1)*max_load, subdat_size[-1]),
                            ]
            print('')
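        # Level n is the full padded grid downsampled by 2**n, mirroring the
        # JPEG2000 wavelet resolutions that WrappedJ2K exposes per level.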

        # Write OME-Zarr multiscale metadata
        print('Write metadata')
        multiscales = [{
            'version': '0.4',
            'axes': [
                {"name": "z", "type": "space", "unit": "micrometer"},
                {"name": "y", "type": "space", "unit": "micrometer"},
                {"name": "x", "type": "space", "unit": "micrometer"}
            ],
            'datasets': [],
            'type': 'jpeg2000',
            'name': '',
        }]
        if has_channel:
            multiscales[0]['axes'].insert(0, {"name": "c", "type": "channel"})

        for n in range(nblevel):
            shape0 = omz['0'].shape[-2:]
            shape = omz[str(n)].shape[-2:]
            multiscales[0]['datasets'].append({})
            level = multiscales[0]['datasets'][-1]
            level["path"] = str(n)

            # I assume that wavelet transforms end up aligning voxel edges
            # across levels, so the effective scaling is the shape ratio,
            # and there is a half-voxel shift wrt the "center of first voxel"
            # frame
            level["coordinateTransformations"] = [
                {
                    "type": "scale",
                    "scale": [1.0] * has_channel + [
                        1.0,
                        (shape0[0]/shape[0])*vxh,
                        (shape0[1]/shape[1])*vxw,
                    ]
                },
                {
                    "type": "translation",
                    "translation": [0.0] * has_channel + [
                        0.0,
                        (shape0[0]/shape[0] - 1)*vxh*0.5,
                        (shape0[1]/shape[1] - 1)*vxw*0.5,
                    ]
                }
            ]
        multiscales[0]["coordinateTransformations"] = [
            {
                "scale": [1.0] * (3 + has_channel),
                "type": "scale"
            }
        ]
        omz.attrs["multiscales"] = multiscales
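        # Per OME-NGFF 0.4, each dataset carries its own scale + translation,
        # while the multiscale as a whole gets an identity top-level scale.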



        # Write sidecar .json file
        json_name = os.path.join(OUT_FOLDER, f'sub-{LINCSUB}_sample-slice{start_num:04d}slice{end_num:04d}_stain-LY_DF.json')
        dic = {}
        dic['PixelSize'] = json.dumps([vxw, vxh])
        dic['PixelSizeUnits'] = 'um'
        dic['SliceThickness'] = 1.2
        dic['SliceThicknessUnits'] = 'mm'
        dic['SampleStaining'] = 'LY'

        with open(json_name, "w") as outfile:
            json.dump(dic, outfile)
            outfile.write('\n')
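        # A quick read-back sanity check could look like this (a sketch,
        # not executed as part of the conversion):
        #   z = zarr.open(out, mode='r')
        #   print(z['0'].shape)                       # (c, z, y, x)
        #   print(z.attrs['multiscales'][0]['axes'])  # z/y/x (+ optional c)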



def orientation_ensure_3d(orientation):
@@ -401,3 +439,4 @@ def get_pixelsize(j2k):

if __name__ == "__main__":
    app()
