-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy path4_reconstruct_shape_image.py
184 lines (136 loc) · 6.17 KB
/
4_reconstruct_shape_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
'''
Reconstruct colored artificial shapes from CNN features decoded from the brain.
- ROI: VC
- Layers: all conv and fc layers
- Reconstruction algorithm: Without DGN + LBFGS
'''
import os
import pickle
from datetime import datetime
from itertools import product
import caffe
import numpy as np
import PIL.Image
import scipy.io as sio
from icnn.icnn_lbfgs import reconstruct_image # Without DGN
from icnn.utils import clip_extreme_value, estimate_cnn_feat_std, normalise_img
# Settings ###################################################################
# GPU usage settings
# Run Caffe in GPU mode and select GPU device 0.
# NOTE(review): assumes a GPU with index 0 is present on the host — adjust the
# device id for other machines.
caffe.set_mode_gpu()
caffe.set_device(0)
# Decoded features settings
decoded_features_dir = './data/decodedfeatures'


def decode_feature_filename(net, layer, subject, roi, image_type, image_label):
    """Build the path of the .mat file holding one set of decoded CNN features.

    The directory is
    ``<decoded_features_dir>/<image_type>/<net>/<layer>/<subject>/<roi>``.
    The file name is
    ``<image_type>-<net>-<layer>-<subject>-<roi>-<image_label>.mat``.
    ``decoded_features_dir`` is read at call time, not at definition time.
    """
    feature_dir = os.path.join(decoded_features_dir, image_type, net, layer, subject, roi)
    mat_name = '%s-%s-%s-%s-%s-%s.mat' % (image_type, net, layer, subject, roi, image_label)
    return os.path.join(feature_dir, mat_name)


# Data settings
results_dir = './results'
subjects_list = ['S1', 'S2', 'S3']
rois_list = ['VC']
network = 'VGG19'

# Images in figure 3A
image_type = 'color_shape'
image_label_list = [
    'Img0001', 'Img0002', 'Img0003', 'Img0004', 'Img0005',
    'Img0006', 'Img0007', 'Img0008', 'Img0009', 'Img0010',
    'Img0011', 'Img0012', 'Img0013', 'Img0014', 'Img0015',
]

# Maximum number of iterations of the reconstruction optimiser
max_iteration = 200
# Main #######################################################################
# Initialize CNN -------------------------------------------------------------
# Average image of ImageNet, collapsed to a single mean value per channel
# (channels 0, 1, 2 of the stored array, in that order).
img_mean_file = './data/ilsvrc_2012_mean.npy'
img_mean = np.load(img_mean_file)
img_mean = np.float32([img_mean[ch].mean() for ch in (0, 1, 2)])

# CNN model
model_file = './net/VGG_ILSVRC_19_layers/VGG_ILSVRC_19_layers.caffemodel'
prototxt_file = './net/VGG_ILSVRC_19_layers/VGG_ILSVRC_19_layers.prototxt'
channel_swap = (2, 1, 0)
net = caffe.Classifier(prototxt_file, model_file,
                       mean=img_mean,
                       channel_swap=channel_swap)

# Keep the network's input height/width but use a batch of a single image.
h, w = net.blobs['data'].data.shape[-2:]
net.blobs['data'].reshape(1, 3, h, w)

# Initial image for the optimization: every pixel is set to the channel means
# of ilsvrc_2012_mean.npy, used as RGB values. The mean vector is applied in
# reverse order, so initial_image[..., 0] = img_mean[2], and so on.
initial_image = np.zeros((h, w, 3), dtype='float32')
initial_image[...] = img_mean[::-1]

# Feature SD estimated from true CNN features of 10000 images
feat_std_file = './data/estimated_vgg19_cnn_feat_std.mat'
feat_std0 = sio.loadmat(feat_std_file)

# CNN Layers (all conv and fc layers), in network blob order
layers = [blob_name for blob_name in net.blobs.keys()
          if ('conv' in blob_name) or ('fc' in blob_name)]
# Setup results directory ----------------------------------------------------
# Results are stored under results_dir, in a sub-directory named after this
# script (without extension). Only the script's base name is used: since
# Python 3.9 the main script's __file__ is an absolute path, and
# os.path.join() discards every earlier component when it is given an
# absolute one, so results_dir would otherwise be ignored.
save_dir_root = os.path.join(results_dir,
                             os.path.splitext(os.path.basename(__file__))[0])
if not os.path.exists(save_dir_root):
    os.makedirs(save_dir_root)
# Set reconstruction options -------------------------------------------------
opts = {
    # The loss function type: {'l2','l1','inner','gram'}
    'loss_type': 'l2',
    # The maximum number of iterations
    'maxiter': max_iteration,
    # The initial image for the optimization (setting to None will use random noise as initial image)
    'initial_image': initial_image,
    # Display the information on the terminal or not
    'disp': True
}

# Save the optional parameters. pickle writes bytes, so the file must be opened
# in binary mode; text mode ('w') raises TypeError under Python 3.
# The per-image 'layer_weight' entry is added to opts only later, so it is not
# part of options.pkl.
with open(os.path.join(save_dir_root, 'options.pkl'), 'wb') as f:
    pickle.dump(opts, f)
# Reconstruction -------------------------------------------------------------
# One reconstruction per (subject, ROI, image) combination.
for subject, roi, image_label in product(subjects_list, rois_list, image_label_list):
    print('')
    print('Subject: ' + subject)
    print('ROI: ' + roi)
    print('Image label: ' + image_label)
    print('')

    save_dir = os.path.join(save_dir_root, subject, roi)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Load the decoded CNN features
    features = {}
    for layer in layers:
        # The file full name depends on the data structure for decoded CNN features
        file_name = decode_feature_filename(network, layer, subject, roi, image_type, image_label)
        feat = sio.loadmat(file_name)['feat']
        if 'fc' in layer:
            # Flatten fully-connected layer features to 1-D
            feat = feat.reshape(feat.size)

        # Correct the norm of the decoded CNN features: rescale them so their
        # estimated SD matches the SD estimated from true CNN features
        # (feat_std0).
        feat_std = estimate_cnn_feat_std(feat)
        feat = (feat / feat_std) * feat_std0[layer]

        features[layer] = feat

    # Weight of each layer in the total loss function.
    # Norm of the CNN features for each layer
    feat_norm = np.array([np.linalg.norm(features[layer]) for layer in layers], dtype='float32')
    # Use the inverse of the squared norm of the CNN features as the weight for each layer
    weights = 1. / (feat_norm ** 2)
    # Normalise the weights such that the sum of the weights = 1
    weights = weights / weights.sum()
    layer_weight = dict(zip(layers, weights))
    opts.update({'layer_weight': layer_weight})

    # Reconstruction (intermediate snapshots are saved per image)
    snapshots_dir = os.path.join(save_dir, 'snapshots', 'image-%s' % image_label)
    recon_img, loss_list = reconstruct_image(features, net,
                                             save_intermediate=True,
                                             save_intermediate_path=snapshots_dir,
                                             **opts)

    # Save the results
    # Save the raw reconstructed image
    save_name = 'recon_img' + '-' + image_label + '.mat'
    sio.savemat(os.path.join(save_dir, save_name), {'recon_img': recon_img})

    # To better display the image, clip pixels with extreme values (0.02% of
    # pixels with extreme low values and 0.02% of the pixels with extreme high
    # values). And then normalise the image by mapping the pixel value to be
    # within [0,255].
    # NOTE(review): pct=0.04 assumes clip_extreme_value splits pct evenly
    # between the low and high tails (0.02% each), as stated above; the
    # previous value pct=4 would clip 2% per side. Confirm against icnn.utils.
    save_name = 'recon_img_normalized' + '-' + image_label + '.jpg'
    PIL.Image.fromarray(normalise_img(clip_extreme_value(recon_img, pct=0.04))).save(os.path.join(save_dir, save_name))

print('Done')