-
Notifications
You must be signed in to change notification settings - Fork 38
/
ADNI_dataset.py
63 lines (53 loc) · 1.6 KB
/
ADNI_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import csv
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import os
from torchvision import transforms
from skimage.transform import resize
from nilearn import surface
import nibabel as nib
class ADNIdataset(Dataset):
    """Dataset of ADNI brain volumes selected by a CSV listing.

    Every CSV row contributes one sample: column 1 is used as the subject
    directory name and column 9 as an ``M/D/YYYY`` date. The date is
    rewritten as ``YYYY-MM-DD`` and used to pick the scan directory under
    ``<root>/<subject>/<basis>/``.

    Args:
        root: Directory that contains one sub-directory per subject.
        augmentation: When true, each sample is randomly flipped along
            axis 0 and multiplied by a random factor in ``[0.7, 1.0)``.
        csv_path: Path of the CSV listing. Defaults to ``'CN_list.csv'``
            in the current working directory.
    """

    def __init__(self, root='../ADNI', augmentation=False,
                 csv_path='CN_list.csv'):
        self.root = root
        self.basis = 'FreeSurfer_Cross-Sectional_Processing_brainmask'
        self.augmentation = augmentation

        names = []
        dates = []
        # The context manager closes the listing even if a row is malformed.
        with open(csv_path, 'r', newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # csv.reader yields [] for blank lines; nothing to index.
                    continue
                month, day, year = row[9].split('/')
                # Pad both month and day so the result is always the
                # 10-character prefix compared against in __getitem__.
                dates.append(year + '-' + month.zfill(2) + '-' + day.zfill(2))
                names.append(row[1])

        self.name = np.asarray(names)
        self.date = np.asarray(dates)

    def __len__(self):
        """Return the number of samples read from the CSV listing."""
        return len(self.name)

    def __getitem__(self, index):
        """Load, resample and normalise the volume for sample ``index``.

        Returns:
            A float tensor of shape ``(1, 64, 64, 64)`` holding the resized
            volume mapped through ``x * 2 - 1``.

        Raises:
            FileNotFoundError: If the subject directory is missing, or if no
                entry in it starts with the sample's ``YYYY-MM-DD`` date.
        """
        subject_dir = os.path.join(self.root, self.name[index], self.basis)
        wanted_date = self.date[index]
        matches = [entry for entry in os.listdir(subject_dir)
                   if entry[:10] == wanted_date]
        if not matches:
            raise FileNotFoundError(
                'no scan directory starting with {!r} in {!r}'.format(
                    str(wanted_date), subject_dir))
        # Keep the last match, as the original loop did.
        scan_dir = os.path.join(subject_dir, matches[-1])
        run_dir = os.listdir(scan_dir)[0]
        mri_dir = os.path.join(scan_dir, run_dir, 'mri')

        img = nib.load(os.path.join(mri_dir, 'image.nii'))
        # np.asanyarray(img.dataobj) is nibabel's documented replacement for
        # the removed get_data(); it keeps the stored dtype.
        # NOTE(review): get_fdata() would always yield float64, which may
        # change how resize() scales intensities - verify before switching.
        volume = np.swapaxes(np.asanyarray(img.dataobj), 1, 2)
        volume = np.flip(volume, 1)
        volume = np.flip(volume, 2)

        sp_size = 64
        volume = resize(volume, (sp_size, sp_size, sp_size), mode='constant')

        if self.augmentation:
            # Both random draws happen for every sample, in this order, so a
            # fixed torch seed reproduces the same augmentation sequence.
            flip_draw = torch.rand(1)
            intensity = 0.3 * torch.rand(1)[0] + 0.7
            if flip_draw[0] > 0.5:
                volume = np.flip(volume, 0)
            # Intensity scaling applies to every augmented sample.
            volume = volume * intensity.item()

        # torch.from_numpy cannot wrap negatively-strided views (np.flip
        # produces them), so force a contiguous buffer first.
        volume = np.ascontiguousarray(volume)
        imageout = torch.from_numpy(volume).float()
        imageout = imageout.view(1, sp_size, sp_size, sp_size)
        imageout = imageout * 2 - 1
        return imageout