-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot_support_prob.py
63 lines (48 loc) · 1.41 KB
/
plot_support_prob.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import sys
import logging
import json
import glob
import copy
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot as plt
import seaborn as sns
from evaluate_model import load_easier_net
def parse_args(args=None):
    """Parse command-line arguments.

    Args:
        args: List of argument strings, excluding the program name. ``None``
            makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``n_inputs``, ``corr``, ``out_support_file``
        and ``fitted_model_files``. ``fitted_model_files`` is the sorted list
        of paths matching the glob pattern given on the command line.

    Exits (via ``parser.error``, i.e. ``SystemExit``) when a required option is
    missing or when the glob pattern matches no files.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--n-inputs", type=int, required=True,
        help="number of input features; length of the output table",
    )
    parser.add_argument(
        "--corr", type=float,
        help="value copied verbatim into the 'corr' column of the output",
    )
    parser.add_argument(
        "--fitted-model-files", type=str, required=True,
        help="glob pattern selecting the fitted model files",
    )
    parser.add_argument(
        "--out-support-file", type=str, required=True,
        help="path of the CSV file to write",
    )
    # Honour the caller-supplied list; previously it was ignored and
    # sys.argv was always parsed.
    parsed = parser.parse_args(args)
    # glob order is arbitrary; sort so runs are reproducible.
    parsed.fitted_model_files = sorted(glob.glob(parsed.fitted_model_files))
    if not parsed.fitted_model_files:
        # Without any files the support probabilities would be 0/0 (NaN).
        parser.error("--fitted-model-files matched no files")
    return parsed
def main(args=None):
    """Estimate, per input, the fraction of fitted models whose support includes it.

    Each file in ``--fitted-model-files`` is loaded with ``load_easier_net``.
    Every index returned by ``model.support()`` gets one count. The counts are
    divided by the number of files and written to ``--out-support-file`` as a
    CSV. The CSV has columns ``input``, ``prob_support`` and ``corr``, plus the
    pandas index column.

    Args:
        args: Command-line arguments, excluding the program name. ``None``
            means ``sys.argv[1:]``, resolved at call time.

    Raises:
        ValueError: if the list of fitted model files is empty, because the
            frequency would otherwise be 0/0.
    """
    args = parse_args(args)
    model_files = args.fitted_model_files
    if not model_files:
        raise ValueError(
            "no fitted model files given; cannot compute support probabilities"
        )

    # Per-input count of how many models selected that input.
    supports = np.zeros(args.n_inputs)
    print("NUM FILES", len(model_files))
    for f in model_files:
        model, _ = load_easier_net(f)
        # NOTE(review): assumes support() yields integer indices in
        # [0, n_inputs); an out-of-range index raises IndexError here.
        for i in model.support():
            supports[i] += 1
    supports /= len(model_files)
    print(supports)

    data = pd.DataFrame({"input": np.arange(args.n_inputs), "prob_support": supports})
    data["corr"] = args.corr
    data.to_csv(args.out_support_file)
if __name__ == "__main__":
    # Script entry point: forward everything after the program name.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)