|
| 1 | + |
| 2 | + |
| 3 | +import argschema |
| 4 | +import os |
| 5 | +from acpreprocessing.utils import io, convert |
| 6 | +import numpy as np |
| 7 | +from scipy.ndimage import gaussian_filter |
| 8 | +import scipy.ndimage as ndimage |
| 9 | +from argschema.fields import Str, Float |
| 10 | +import matplotlib.pyplot as plt |
| 11 | + |
| 12 | +#Adapted from Jun Wang's code |
| 13 | + |
| 14 | +example_input = { |
| 15 | + "input_filename": "/Users/sharmishtaas/Documents/data/axonal/M6data/Section_13/ex1_2_13.tif", |
| 16 | + "flatten_method": "top", |
| 17 | + "output_filename": "test.tif", |
| 18 | + "flip_back": False |
| 19 | +} |
| 20 | + |
| 21 | + |
| 22 | +def ndfilter(img,sig=3): |
| 23 | + img = ndimage.gaussian_filter(img, sigma=(sig, sig, 5), order=0) |
| 24 | + return img |
| 25 | + |
| 26 | +def flatten_bottom(img, bottom): |
| 27 | + """ |
| 28 | + shift the height of each pixel to align the bottom of the section |
| 29 | + :param img: 3d array |
| 30 | + :param bottom: 2d array, int, indices of bottom surface |
| 31 | + :return imgb: 3d img, same size as img, |
| 32 | + """ |
| 33 | + |
| 34 | + if len(img.shape) != 3: |
| 35 | + raise ValueError('input array should be 3d.') |
| 36 | + |
| 37 | + if bottom.shape != (img.shape[1], img.shape[2]): |
| 38 | + raise ValueError('the shape of top should be the same size as each plane in img.') |
| 39 | + |
| 40 | + imgb = np.zeros(img.shape, dtype=img.dtype) |
| 41 | + |
| 42 | + z, y, x = img.shape |
| 43 | + |
| 44 | + for yi in range(y): |
| 45 | + for xi in range(x): |
| 46 | + b = bottom[yi, xi] |
| 47 | + if b!= 0: |
| 48 | + col = img[:b, yi, xi] |
| 49 | + imgb[-len(col):, yi, xi] = col |
| 50 | + |
| 51 | + imgb = imgb[-np.amax(bottom):, :, :] |
| 52 | + |
| 53 | + return imgb |
| 54 | + |
| 55 | + |
| 56 | +def flatten_top(img, top): |
| 57 | + """ |
| 58 | + shift the height of each pixel to align the top of the section |
| 59 | + :param img: 3d array |
| 60 | + :param top: 2d array, int, indices of top surface |
| 61 | + :return imgft: 3d img, same size as img, |
| 62 | + """ |
| 63 | + |
| 64 | + if len(img.shape) != 3: |
| 65 | + raise ValueError('input array should be 3d.') |
| 66 | + |
| 67 | + if top.shape != (img.shape[1], img.shape[2]): |
| 68 | + raise ValueError('the shape of top should be the same size as each plane in img.') |
| 69 | + |
| 70 | + imgt = np.zeros(img.shape, dtype=img.dtype) |
| 71 | + |
| 72 | + z, y, x = img.shape |
| 73 | + |
| 74 | + for yi in range(y): |
| 75 | + for xi in range(x): |
| 76 | + t = top[yi, xi] |
| 77 | + col = img[t:, yi, xi] |
| 78 | + imgt[:len(col), yi, xi] = col |
| 79 | + |
| 80 | + #imgt = imgt[-(z-np.amin(top)):, :, :] |
| 81 | + return imgt |
| 82 | + |
| 83 | +def up_crossings(data, threshold=0): |
| 84 | + """ |
| 85 | + find the index where the data up cross the threshold. return the indices of all up crossings (the onset data point |
| 86 | + that is greater than threshold, 1d-array). The input data should be 1d array. |
| 87 | + """ |
| 88 | + if len(data.shape) != 1: |
| 89 | + raise ValueError('Input data should be 1-d array.') |
| 90 | + |
| 91 | + pos = data > threshold |
| 92 | + return (~pos[:-1] & pos[1:]).nonzero()[0] + 1 |
| 93 | + |
| 94 | +def find_surface(img, surface_thr, top_buffer = 0, bot_buffer = 30, is_plot=False): |
| 95 | + """ |
| 96 | + :param img: 3d array, ZYX, assume small z = top; large z = bottom |
| 97 | + :param surface_thr: [0, 1], threshold for detecting surface |
| 98 | + :return top: 2d array, same size as each plane in img, z index of top surface |
| 99 | + :return bot: 2d array, same size as each plane in img, z index of bottom surface |
| 100 | + """ |
| 101 | + |
| 102 | + if len(img.shape) != 3: |
| 103 | + raise ValueError('input array should be 3d.') |
| 104 | + |
| 105 | + z, y, x = img.shape |
| 106 | + |
| 107 | + top = np.zeros((y, x), dtype=np.int) |
| 108 | + bot = np.ones((y, x), dtype=np.int) * z |
| 109 | + |
| 110 | + if is_plot: |
| 111 | + f = plt.figure(figsize=(5, 5)) |
| 112 | + ax = f.add_subplot(111) |
| 113 | + |
| 114 | + for yi in range(y): |
| 115 | + for xi in range(x): |
| 116 | + curr_t = img[:, yi, xi] |
| 117 | + mx = curr_t.max() |
| 118 | + mn = curr_t.min() |
| 119 | + if mx != mn: |
| 120 | + curr_t = (curr_t - mn) / (mx - mn) |
| 121 | + |
| 122 | + if is_plot: |
| 123 | + if yi % 10 == 0 and xi % 10 == 0: |
| 124 | + ax.plot(range(len(curr_t)), curr_t, '-b', lw=0.5, alpha=0.1) |
| 125 | + |
| 126 | + if curr_t[0] < surface_thr: |
| 127 | + curr_top = up_crossings(curr_t, surface_thr) |
| 128 | + cur_top = cur_top + top_buffer |
| 129 | + if len(curr_top) != 0: |
| 130 | + top[yi, xi] = curr_top[0] |
| 131 | + |
| 132 | + if curr_t[-1] < surface_thr: |
| 133 | + curr_bot = down_crossings(curr_t, surface_thr) |
| 134 | + curr_bot = curr_bot+bot_buffer |
| 135 | + if len(curr_bot) != 0: |
| 136 | + bot[yi, xi] = curr_bot[-1] |
| 137 | + |
| 138 | + if is_plot: |
| 139 | + plt.show() |
| 140 | + |
| 141 | + return top, bot |
| 142 | + |
| 143 | +def down_crossings(data, threshold=0): |
| 144 | + """ |
| 145 | + find the index where the data down cross the threshold. return the indices of all down crossings (the onset data |
| 146 | + point that is less than threshold, 1d-array). The input data should be 1d array. |
| 147 | + """ |
| 148 | + if len(data.shape) != 1: |
| 149 | + raise ValueError('Input data should be 1-d array.') |
| 150 | + |
| 151 | + pos = data < threshold |
| 152 | + return (~pos[:-1] & pos[1:]).nonzero()[0] + 1 |
| 153 | + |
| 154 | + |
| 155 | +def flatten_both_sides(img, top, bottom): |
| 156 | + """ |
| 157 | + flatten both sides by interpolation |
| 158 | + :param img: 3d array |
| 159 | + :param top: 2d array, int, indices of top surface |
| 160 | + :param bottom: 2d array, int, indices of bottom surface |
| 161 | + :return imgtb: 3d img |
| 162 | + """ |
| 163 | + |
| 164 | + if len(img.shape) != 3: |
| 165 | + raise ValueError('input array should be 3d.') |
| 166 | + |
| 167 | + if bottom.shape != (img.shape[1], img.shape[2]): |
| 168 | + raise ValueError('the shape of top should be the same size as each plane in img.') |
| 169 | + |
| 170 | + if top.shape != (img.shape[1], img.shape[2]): |
| 171 | + raise ValueError('the shape of top should be the same size as each plane in img.') |
| 172 | + |
| 173 | + z, y, x = img.shape |
| 174 | + |
| 175 | + depths = bottom - top |
| 176 | + depth = int(np.median(depths.flat)) |
| 177 | + |
| 178 | + |
| 179 | + imgtb = np.zeros((depth, y, x), dtype=img.dtype) |
| 180 | + |
| 181 | + colz_tb = np.arange(depth) |
| 182 | + |
| 183 | + for yi in range(y): |
| 184 | + for xi in range(x): |
| 185 | + col = img[top[yi, xi]:bottom[yi, xi], yi, xi] |
| 186 | + colz = np.arange(len(col)) |
| 187 | + imgtb[:, yi, xi] = np.interp(x=colz_tb, xp=colz, fp=col) |
| 188 | + |
| 189 | + return imgtb |
| 190 | + |
| 191 | + |
| 192 | +class FlattenSchema(argschema.ArgSchema): |
| 193 | + input_filename = Str(required=True, description='Input File') |
| 194 | + flatten_method = Str(required=True, validator=marshmallow.validate.OneOf(["top", "bottom"]), description='Type of flattening (top, bottom, both)') |
| 195 | + output_filename = Str(required=True, description='Output File') |
| 196 | + threshold = Float(default=0.3, description = 'Threshold for finding surface') |
| 197 | + |
| 198 | + |
| 199 | +class Flatten(argschema.ArgSchemaParser): |
| 200 | + default_schema = FlattenSchema |
| 201 | + |
| 202 | + |
| 203 | + def run(self): |
| 204 | + thresh = self.args['threshold'] |
| 205 | + I = io.get_tiff_image(self.args['input_filename']) |
| 206 | + IM = convert.downsample_stack_volume(I, dsfactors = (2,4,4)) |
| 207 | + I_flip = np.rot90(IM,1,(0,2)) |
| 208 | + I_flip_smoothed = ndfilter(I_flip) |
| 209 | + |
| 210 | + top,bottom = find_surface(I_flip_smoothed, thresh, is_plot=False) |
| 211 | + |
| 212 | + if self.args['flatten_method'] == 'top': |
| 213 | + I_flat = flatten_top(I_flip,top) |
| 214 | + elif self.args['flatten_method'] == 'bottom': |
| 215 | + I_flat = flatten_bottom(I_flip,bottom) |
| 216 | + else: |
| 217 | + print("Please choose correct flattening method: top or bottom") |
| 218 | + sys.exit() |
| 219 | + |
| 220 | + if self.args['flip_back']: |
| 221 | + I_flat = np.rot90(I_flat,3,(0,2)) |
| 222 | + io.save_tiff_image(I_flat, self.args['output_filename']) |
| 223 | + |
| 224 | + |
| 225 | + |
| 226 | + |
| 227 | + |
| 228 | + |
| 229 | + |
| 230 | +if __name__ == '__main__': |
| 231 | + mod = Flatten(example_input) |
| 232 | + |
| 233 | + mod.run() |
0 commit comments