-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate_borders_lits.py
30 lines (27 loc) · 1.3 KB
/
generate_borders_lits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# This script generates boolean volumes marking the border regions of the given MRI segmentation labels (LITS)
# and saves them as .pt files next to the segmentations.
import os
import nibabel
import torch
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import sobel, generic_gradient_magnitude
from tqdm import tqdm
# Dataset roots; each must contain a "segmentation" folder with the label volumes.
# The resulting border masks are written to a sibling "border" folder.
paths = ["LITS Challenge/TrainingData", "LITS Challenge/TestData"]
# Run the pooling on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for path in paths:
    segmentation_dir = os.path.join(path, "segmentation")
    border_dir = os.path.join(path, "border")
    # torch.save does not create missing folders, so make sure the target exists
    os.makedirs(border_dir, exist_ok=True)
    # List the folder once (sorted for a reproducible order); tqdm tracks and closes the bar itself
    for file in tqdm(sorted(os.listdir(path=segmentation_dir))):
        # Load segmentation; every non-zero label counts as foreground
        segmentation = nibabel.load(filename=os.path.join(segmentation_dir, file)).get_fdata() > 0
        # Calc gradients (non-zero only where the foreground mask changes)
        gradients = generic_gradient_magnitude(segmentation, sobel).astype(np.float32)
        # Dilate the gradient with a 13x13x13 max pooling (stride 1, padding 6 keeps the shape)
        # to get the border region
        with torch.no_grad():
            border_region = F.max_pool3d(torch.from_numpy(gradients)[None, None].to(device),
                                         kernel_size=(13, 13, 13), stride=(1, 1, 1), padding=(6, 6, 6))[0, 0].cpu()
        # Save borders as a boolean tensor (True wherever the dilated gradient is non-zero)
        torch.save(border_region.bool(),
                   os.path.join(border_dir, file.replace("segmentation", "border").replace(".nii", ".pt")))