-
Notifications
You must be signed in to change notification settings - Fork 0
/
prueba_SHAP.py
58 lines (41 loc) · 1.5 KB
/
prueba_SHAP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch, torchvision
from torch import nn
from torchvision import transforms, models, datasets
import shap
import json
import numpy as np
# Per-channel mean and standard deviation used to standardize images before
# they are fed to the network (one value per colour channel, channels last).
# NOTE(review): these are the values commonly documented for ImageNet-pretrained
# torchvision models -- confirm if the model is ever swapped.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize(image):
    """Standardize a batch of channels-last images and return a float tensor.

    The input is expected to be a 4-D array laid out as
    ``(batch, height, width, channel)`` whose last axis matches the length of
    ``mean``/``std``.  Values above 1 are taken to be on a 0-255 scale and are
    divided by 255 first.  The result is ``(image - mean) / std`` with the
    axes reordered to ``(batch, channel, height, width)``.

    The caller's array is never modified, and integer-typed images
    (e.g. ``uint8``) are accepted.

    Args:
        image: array-like batch of images, channels last.

    Returns:
        A ``torch.float32`` tensor of shape ``(batch, channel, height, width)``.
    """
    image = np.asarray(image)
    if image.max() > 1:
        # Out-of-place division: an in-place ``/=`` would silently rescale the
        # caller's data and raises for integer dtypes.
        image = image / 255
    image = (image - mean) / std
    # Reorder the axes from (N, H, W, C) to (N, C, H, W) so they suit PyTorch.
    return torch.tensor(image.swapaxes(-1, 1).swapaxes(2, 3)).float()
# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still runs on machines without CUDA instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained VGG16 model and switch it to evaluation mode.
# NOTE(review): recent torchvision releases deprecate ``pretrained=True`` in
# favour of the ``weights=`` argument; kept as-is for compatibility with the
# installed version.
model = models.vgg16(pretrained=True)
model.to(device)
model.eval()

# Load the sample images bundled with SHAP and divide by 255
# (presumably the pixel values arrive on a 0-255 scale -- confirm).
X, y = shap.datasets.imagenet50()
X /= 255

# The two images whose predictions will be explained.
to_explain = X[[39, 41]]

# Load the ImageNet class names.
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
fname = shap.datasets.cache(url)
with open(fname) as f:
    class_names = json.load(f)

# Explain the model with respect to the input of layer ``model.features[7]``,
# using the whole normalized sample set as the background data.
e = shap.GradientExplainer((model, model.features[7]), normalize(X).to(device))
shap_values, indexes = e.shap_values(normalize(to_explain).to(device), ranked_outputs=2, nsamples=200)

# Get the names for the classes (indexes are moved to the CPU first so NumPy
# can iterate over them).
index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes.cpu())

# Plot the explanations: move the channel axis back to the end
# (N, C, H, W) -> (N, H, W, C) before handing the values to the plotter.
shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values]
shap.image_plot(shap_values, to_explain, index_names)