forked from nilearn/nilearn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_haxby_simple.py
119 lines (83 loc) · 3.69 KB
/
plot_haxby_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Simple example of decoding: the Haxby data
==============================================

Here is a simple example of decoding, based on the Haxby 2001 study: we
discriminate the 'face' and 'cat' conditions from fMRI activity restricted
to a mask of the ventral stream.
"""
### Load haxby dataset ########################################################
from nilearn import datasets

# Fetch the Haxby 2001 dataset through nilearn's dataset fetcher; the
# returned object gives access to the functional run, the labels file and
# the ventral-temporal mask used below.
data = datasets.fetch_haxby()
### Load Target labels ########################################################
import numpy as np

# Load target information as strings and give a numerical identifier to each.
# ``np.recfromcsv`` is deprecated (removed in NumPy 2.0) and, on Python 3,
# returns the labels as ``bytes``, which never compare equal to the ``str``
# literals used for ``condition_mask`` below. ``genfromtxt`` with an explicit
# encoding returns ``str`` labels and the same field names (header row,
# lower-cased) as ``recfromcsv`` did.
labels = np.genfromtxt(data.session_target[0], delimiter=" ", names=True,
                       dtype=None, encoding="utf-8", case_sensitive="lower")
# scikit-learn >= 0.14 supports text labels. You can replace this line by:
# target = labels['labels']
_, target = np.unique(labels['labels'], return_inverse=True)

### Keep only data corresponding to faces or cat ##############################
# Boolean row selector; it is applied to ``target`` here and must be applied
# to the fMRI samples as well so both stay aligned.
condition_mask = np.logical_or(labels['labels'] == 'face',
                               labels['labels'] == 'cat')
target = target[condition_mask]
### Prepare the data: apply the mask ##########################################
# NOTE(review): recent nilearn releases expose NiftiMasker from
# ``nilearn.maskers`` and call this argument ``mask_img``; the spelling below
# targets the nilearn version this fork was written against -- confirm.
from nilearn.input_data import NiftiMasker

# For decoding, standardizing is often very important
nifti_masker = NiftiMasker(mask=data.mask_vt[0], standardize=True)
# We give the nifti_masker a filename and retrieve a 2D array ready
# for machine learning with scikit-learn
fmri_masked = nifti_masker.fit_transform(data.func[0])
# Restrict the classification to the face vs cat discrimination: keep the
# same rows that ``condition_mask`` kept in ``target``
fmri_masked = fmri_masked[condition_mask]
### Prediction ################################################################
from sklearn.svm import SVC

# A Support Vector Classifier with a linear kernel. ``fit`` returns the
# estimator itself, so fitting and predicting can be chained.
svc = SVC(kernel='linear')
# Note: this predicts the very samples used for fitting, so it only reflects
# the training-set fit; generalization is estimated by cross-validation.
prediction = svc.fit(fmri_masked, target).predict(fmri_masked)
### Cross-validation ##########################################################
# ``sklearn.cross_validation`` was removed in scikit-learn 0.20; the same
# splitter and helpers live in ``sklearn.model_selection`` (>= 0.18).
from sklearn.model_selection import KFold, cross_val_score

# Unshuffled 5-fold split, equivalent to the former ``KFold(n, n_folds=5)``.
cv = KFold(n_splits=5)
# ``cross_val_score`` fits a clone of ``svc`` on each training fold and
# scores it (accuracy) on the held-out fold. Unlike refitting ``svc`` in a
# manual loop, this leaves ``svc`` fitted on all the data, which is the
# model whose weights are unmasked below.
cv_scores = cross_val_score(svc, fmri_masked, target, cv=cv)
print(cv_scores)
### Unmasking #################################################################
import nibabel

# Discriminating weights of the fitted linear SVC: one value per feature,
# i.e. per column of the masked fMRI matrix.
coef_ = svc.coef_
# Map the weights back into brain space with the same masker, then write
# the resulting image to disk as a Nifti file.
coef_niimg = nifti_masker.inverse_transform(coef_)
nibabel.save(coef_niimg, 'haxby_svc_weights.nii')
### Visualization #############################################################
import matplotlib.pyplot as plt

### Create the figure and plot the first EPI image as a background
plt.figure(figsize=(3, 5))
epi_img = nibabel.load(data.func[0])
# Slice through the image's data proxy so only the displayed slice (z=27,
# first volume) is read into memory, instead of materializing the whole 4D
# run; ``get_data()`` is also deprecated and removed in nibabel 5.0.
background = epi_img.dataobj[..., 27, 0]
plt.imshow(np.rot90(background), interpolation='nearest', cmap=plt.cm.gray)

### Plot the SVM weights
weights = np.asanyarray(coef_niimg.dataobj)
# We use a masked array so that voxels whose weight is exactly 0 (presumably
# those outside the mask) are displayed transparently
weights = np.ma.masked_array(weights, weights == 0)
plt.imshow(np.rot90(weights[..., 27, 0]), cmap=plt.cm.hot,
           interpolation='nearest')
plt.axis('off')
plt.title('SVM weights')
plt.tight_layout()
### Visualize the mask ########################################################
# ``np.asanyarray(img.dataobj)`` is nibabel's replacement for the removed
# ``get_data()`` and keeps the stored dtype.
mask = np.asanyarray(nifti_masker.mask_img_.dataobj)
plt.figure()
plt.axis('off')
# ``nibabel.load`` returns a lazily-loaded image; slicing its ``dataobj``
# reads only the background slice rather than the full 4D run.
plt.imshow(np.rot90(nibabel.load(data.func[0]).dataobj[..., 27, 0]),
           interpolation='nearest', cmap=plt.cm.gray)
# Hide voxels outside the mask (value 0) so only the mask is overlaid
ma = np.ma.masked_equal(mask, 0)
plt.imshow(np.rot90(ma[..., 27]), interpolation='nearest', cmap=plt.cm.autumn,
           alpha=0.5)
plt.title("Mask")
plt.tight_layout()
plt.show()