-
Notifications
You must be signed in to change notification settings - Fork 55
/
loader.py
82 lines (70 loc) · 3.37 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from glob import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader
class ct_dataset(Dataset):
    """Paired input/target CT slices stored as ``.npy`` files.

    Files are discovered in ``saved_path`` as ``*_input.npy`` and
    ``*_target.npy`` and paired by their sorted order. Files whose *file name*
    contains ``test_patient`` form the test split; all other files form the
    training split.

    Args:
        mode: ``'train'`` or ``'test'``.
        load_mode: ``0`` keeps file paths and loads each pair lazily in
            ``__getitem__``; ``1`` loads every array into memory up front.
        saved_path: directory containing the ``.npy`` files.
        test_patient: substring identifying the held-out patient's files.
        patch_n: number of random patches per item (used with ``patch_size``).
        patch_size: side length of square patches; falsy disables patching.
        transform: optional callable applied separately to input and target.

    Raises:
        AssertionError: if ``mode`` or ``load_mode`` is not an allowed value.
        ValueError: if the selected split has a different number of input
            and target files (the sorted-order pairing would be wrong).
    """

    def __init__(self, mode, load_mode, saved_path, test_patient, patch_n=None, patch_size=None, transform=None):
        assert mode in ['train', 'test'], "mode is 'train' or 'test'"
        assert load_mode in [0,1], "load_mode is 0 or 1"
        input_path = sorted(glob(os.path.join(saved_path, '*_input.npy')))
        target_path = sorted(glob(os.path.join(saved_path, '*_target.npy')))
        self.load_mode = load_mode
        self.patch_n = patch_n
        self.patch_size = patch_size
        self.transform = transform

        # The test split is the held-out patient; the train split is everyone else.
        want_test = mode == 'test'
        input_ = self._split(input_path, test_patient, want_test)
        target_ = self._split(target_path, test_patient, want_test)
        # Inputs and targets are paired purely by sorted position, so an
        # unequal count means the pairing is broken.
        if len(input_) != len(target_):
            raise ValueError(
                f"found {len(input_)} input files but {len(target_)} target files "
                f"for mode={mode!r} in {saved_path!r}"
            )

        if load_mode == 0:  # batch data load: keep paths, load lazily
            self.input_ = input_
            self.target_ = target_
        else:  # all data load: read every array now
            self.input_ = [np.load(f) for f in input_]
            self.target_ = [np.load(f) for f in target_]

    @staticmethod
    def _split(paths, test_patient, want_test):
        """Return paths whose file name does (``want_test``) or does not contain ``test_patient``."""
        # Match on the basename only, so a directory name that happens to
        # contain the patient id does not pull in (or drop) every file.
        return [p for p in paths if (test_patient in os.path.basename(p)) == want_test]

    def __len__(self):
        """Return the number of input/target pairs."""
        return len(self.target_)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(input, target)`` pair.

        With ``load_mode == 0`` the arrays are read from disk here. If
        ``transform`` is set, it is applied to input and target separately.
        If ``patch_size`` is set, ``patch_n`` aligned random patches are cropped
        with ``get_patch`` and returned instead of the full images.
        """
        input_img, target_img = self.input_[idx], self.target_[idx]
        if self.load_mode == 0:
            input_img, target_img = np.load(input_img), np.load(target_img)

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        if self.patch_size:
            input_patches, target_patches = get_patch(input_img,
                                                      target_img,
                                                      self.patch_n,
                                                      self.patch_size)
            return (input_patches, target_patches)
        else:
            return (input_img, target_img)
def get_patch(full_input_img, full_target_img, patch_n, patch_size):
    """Crop ``patch_n`` aligned random square patches from an input/target pair.

    The same random window is applied to both images, so each input patch
    stays pixel-aligned with its target patch. Cropping acts on the last two
    axes; any leading axes (e.g. a channel axis) are carried through unchanged.

    Args:
        full_input_img: input image; its last two axes are height and width.
        full_target_img: target image with the same shape as the input.
        patch_n: number of patches to draw.
        patch_size: side length of each square patch.

    Returns:
        Tuple ``(input_patches, target_patches)`` built with ``np.array`` over
        the per-patch crops, i.e. shape ``(patch_n, ..., patch_size, patch_size)``
        for array inputs.

    Raises:
        AssertionError: if the two images differ in shape.
        ValueError: if ``patch_size`` exceeds either spatial dimension.
    """
    assert full_input_img.shape == full_target_img.shape
    h, w = full_input_img.shape[-2:]
    new_h, new_w = patch_size, patch_size
    if new_h > h or new_w > w:
        raise ValueError(f"patch_size {patch_size} exceeds image size {h}x{w}")

    patch_input_imgs = []
    patch_target_imgs = []
    for _ in range(patch_n):
        # randint's upper bound is exclusive: +1 makes the bottom/right-most
        # window reachable and allows patch_size equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        window = (Ellipsis, slice(top, top + new_h), slice(left, left + new_w))
        patch_input_imgs.append(full_input_img[window])
        patch_target_imgs.append(full_target_img[window])
    return np.array(patch_input_imgs), np.array(patch_target_imgs)
def get_loader(mode='train', load_mode=0,
               saved_path=None, test_patient='L506',
               patch_n=None, patch_size=None,
               transform=None, batch_size=32, num_workers=6, shuffle=None):
    """Build a ``DataLoader`` over a ``ct_dataset`` split.

    Args:
        mode: ``'train'`` or ``'test'``; forwarded to ``ct_dataset``.
        load_mode: ``0`` (lazy per-item loading) or ``1`` (load everything).
        saved_path: directory containing the ``*_input.npy`` / ``*_target.npy`` files.
        test_patient: id of the held-out patient used for the test split.
        patch_n: number of random patches per item (with ``patch_size``).
        patch_size: side length of square patches; falsy disables patching.
        transform: optional callable applied to input and target separately.
        batch_size: samples per batch.
        num_workers: worker processes used by the ``DataLoader``.
        shuffle: whether to reshuffle every epoch. ``None`` (default) shuffles
            only in ``'train'`` mode, so evaluation visits slices in a stable,
            reproducible order. Pass ``True`` to force shuffling in test mode.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    dataset_ = ct_dataset(mode, load_mode, saved_path, test_patient, patch_n, patch_size, transform)
    if shuffle is None:
        shuffle = mode == 'train'
    data_loader = DataLoader(dataset=dataset_, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return data_loader