-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_saliency.py
59 lines (38 loc) · 1.46 KB
/
run_saliency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# --- Imports and global torch configuration ---
from Utils.dict2Class import Dict2Class
import torch
# Enable autograd anomaly detection so backward passes report the forward op
# that produced a NaN/Inf gradient. This slows execution noticeably and is
# intended for debugging runs; placed before the other project imports so it
# is active for anything they construct at import time.
torch.autograd.set_detect_anomaly(True)
from Dataset.dataset import getDataset
from SaliencyNew.sal_test import test
from SaliencyNew.sal_train import train
from SaliencyMap.get_options import getOptions
from SaliencyMap.get_saliency import Saliency
from Model.saliencyModel import Model
import os
# --- Experiment setup ---
# Identifier of the experiment configuration to load.
exp = "sage_hcp"
options = getOptions(exp)
# Seed torch's RNG so fold-level runs are reproducible.
torch.manual_seed(options.seed)
dataset = getDataset(options)
# Directory holding the per-fold artifacts of a previous training run.
# NOTE(review): workDir and logRelDir are deliberately concatenated with no
# separator (logRelDir appears to be a fragment relative to workDir) —
# confirm against getOptions. The remaining components are joined portably.
targetDir = os.path.join(options.workDir + options.logRelDir,
                         options.expGroup, options.expName)
saveName = options.saveName
# Only the first configured dataset is analysed by this script.
dname = options.datasets[0]
# --- Saliency analysis: one pass per cross-validation fold ---
for fold in range(options.kFold):
    print(f"\n running fold {fold} on device {options.device} \n")
    # Checkpoint of the model trained for this fold in an earlier run.
    modelDir = os.path.join(targetDir, f"fold_{fold}", "foldModel.torch")
    print(modelDir)
    # map_location remaps storages onto the configured device, so the load
    # works even when the checkpoint was saved on a different device
    # (e.g. GPU-trained checkpoint analysed on CPU).
    # NOTE(review): this loads a full pickled module object; on torch >= 2.6
    # that additionally requires weights_only=False — confirm the torch
    # version this project pins.
    frozenModel = torch.load(modelDir, map_location=options.device).model
    # Wrap the frozen fold model in the saliency Model for gradient analysis.
    model = Model(options, frozenModel)
    testDict = Dict2Class({
        "dataset": dataset,
        "fold": fold,
        "model": model,
        "saveName": saveName,
        "dname": dname,
    })
    # Per the original author: train() here performs the same pass as testing
    # would, driving the saliency extraction.
    train(testDict)
    # Alternative entry points kept for reference:
    # test(testDict)
    # saliency = Saliency(options)
    # distScores = saliency.getDist()
    # print(distScores)