-
Notifications
You must be signed in to change notification settings - Fork 0
/
xregions.py
131 lines (115 loc) · 4.66 KB
/
xregions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import sys
import time
import argparse
import logging
import random
from src.model import Model
from src.explainer import ExplanationProgram
from benchmark.benchmark import benchmark_all
# Send every log record to stdout, tagged with its level and a 12-hour
# month/day/year timestamp; INFO and above are shown unless --loglevel
# overrides the root level later on.
logging.basicConfig(
    format="[%(levelname)s|%(asctime)s] %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=logging.INFO,
    stream=sys.stdout,
)

# Seed used to reseed `random` before drawing a "random" instance, so the
# same instance is produced on every run.
SEED = 21023
def get_lims(fname):
    """Read per-feature domain limits from a comma-separated text file.

    Each non-blank line is expected to look like ``index,lower,upper``;
    any columns after the third are ignored. Blank or whitespace-only lines
    (for example a trailing empty line at the end of the file) are skipped.

    Args:
        fname: Path of the limits file.

    Returns:
        dict mapping the integer feature index to a ``(lower, upper)`` tuple
        of floats. Keys are inserted in file order, and callers such as
        ``random_x`` iterate the values in that order. A repeated index keeps
        the last line seen.

    Raises:
        ValueError: if a field cannot be converted to ``int``/``float``.
        IndexError: if a non-blank line has fewer than three fields.
    """
    lims = {}
    with open(fname, "r") as f:
        for raw in f:
            # The previous readline() loop crashed on blank lines because
            # "\n" is truthy and int("\n") raises ValueError.
            if not raw.strip():
                continue
            fields = raw.split(",")
            lims[int(fields[0])] = (float(fields[1]), float(fields[2]))
    return lims
def random_x(lims):
    """Draw one uniform sample per entry of *lims*.

    Each value of *lims* supplies its bounds as ``bounds[0]`` (low) and
    ``bounds[1]`` (high). Samples are taken with the module-level ``random``
    generator, in the iteration order of ``lims.values()``, so reseeding
    ``random`` beforehand makes the result reproducible.

    NOTE(review): the output is ordered by dict iteration order, not by the
    integer keys; this presumably relies on the limits file listing features
    in index order — confirm.
    """
    sample = []
    for bounds in lims.values():
        low, high = bounds[0], bounds[1]
        sample.append(random.uniform(low, high))
    return sample
def _str2bool(value):
    """Parse a command-line boolean (true/false, yes/no, 1/0, on/off).

    ``argparse`` with ``type=bool`` maps every non-empty string, including
    ``"False"``, to ``True``, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Demonstration for loading and printing XGBoost model.")
    parser.add_argument("-m", "--model",
                        type=str,
                        required=True,
                        help="Name of a model in the models folder.")
    # Exactly one action must be chosen.
    action_group = parser.add_mutually_exclusive_group(required=True)
    action_group.add_argument("-E", "--enumerate",
                              type=str,
                              help="Enumerate all explanations for an instance.")
    action_group.add_argument("-e", "--explain",
                              type=str,
                              help="Generate one explanation for a given instance.")
    action_group.add_argument("--benchmark-all",
                              action="store_true",
                              required=False,
                              help="Run benchmarks for given seed generation method.")
    parser.add_argument("--loglevel",
                        type=str,
                        required=False,
                        help="Program logging level.")
    # nargs="?" + const=True: a bare "--block-score" means True, while
    # "--block-score True/False" keeps working (and False is now honoured).
    parser.add_argument("--block-score",
                        type=_str2bool,
                        nargs="?",
                        const=True,
                        default=False,
                        required=False,
                        help="Whether or not to block score when enumerating.")
    parser.add_argument("--seed-gen",
                        type=str,
                        default="rand",
                        required=False,
                        help="Seed generation method: (rand|min|max)")
    return parser


def _configure_logging(loglevel):
    """Set the root logger level from a level name such as ``"debug"``.

    Raises:
        ValueError: if *loglevel* does not name an integer level in ``logging``.
    """
    numeric_level = getattr(logging, loglevel.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError("Invalid log level: %s" % loglevel)
    logging.getLogger().setLevel(numeric_level)


def _parse_instance(raw, lims):
    """Convert the -e/-E argument into a list of floats.

    The literal ``"random"`` reseeds ``random`` with ``SEED`` and draws one
    value per entry of *lims* via ``random_x``; any other string is parsed as
    comma-separated numbers.
    """
    if raw == "random":
        random.seed(SEED)
        return random_x(lims)
    return [float(x) for x in raw.split(",")]


def main():
    """Command-line entry point.

    Runs ``benchmark_all`` when ``--benchmark-all`` is given. Otherwise it
    loads ``models/<model>.json`` and ``models/<model>.lims``, builds an
    ``ExplanationProgram`` and either explains (-e) or enumerates
    explanations (-E) for the given instance.
    """
    args = _build_parser().parse_args()

    # Apply the log level before any work so it also affects --benchmark-all.
    if args.loglevel:
        _configure_logging(args.loglevel)

    if args.benchmark_all:
        benchmark_all(args.seed_gen)
        return

    model_path = f"models/{args.model}.json"
    lims_path = f"models/{args.model}.lims"

    with open(model_path, "r") as fd:
        raw_model = json.load(fd)
    model = Model(raw_model)
    logging.info(f"successfully initialised {model_path}")
    lims = get_lims(lims_path)
    logging.info(f"successfully initialised domain limits {lims_path}")

    seed_gen = args.seed_gen
    block_score = args.block_score
    # The mutually exclusive group guarantees exactly one of these is set here.
    raw_instance = args.explain if args.explain is not None else args.enumerate
    instance = _parse_instance(raw_instance, lims)

    program = ExplanationProgram(model, limits=lims, seed_gen=seed_gen, mpath=model_path)
    n_classes = 2 if 'binary' in model.objective else model.num_output_group
    logging.info(
        "\nPROGRAM INFO:\n"
        f"\tObjective: {model.objective}\n"
        f"\tClasses: {n_classes}\n"
        f"\tFeatures: {model.num_feature}\n"
        f"\tTrees: {model.num_trees}\n"
        f"\tSeed Generation: {seed_gen}\n"
        f"\tThresholds: {program.fs_info.n_thresholds()}\n"
        f"\tPairs: {program.fs_info.n_pairs()}\n"
        f"\tPossible Regions: {program.fs_info.n_regions()}"
    )

    c = program.entailer.predict(instance)
    if args.explain is not None:
        logging.info(f"EXPLAIN: {instance} -> {c} | block_score: {block_score}")
        r = program.explain(instance)
        logging.info(f"COMPLETE:\n{r}")
    elif args.enumerate is not None:
        logging.info(
            "THRESHOLDS:\n"
            + "\n".join(f"{i}: {program.fs_info.get_domain(i)}"
                        for i in sorted(model.thresholds.keys()))
        )
        logging.info(f"ENUMERATE EXPLANATIONS: {instance} -> {c} | block_score: {block_score}")
        # Results are not used here; the loop only drives the iterator to
        # exhaustion (any reporting presumably happens inside the program).
        for _ in program.enumerate_explanations(instance, block_score=block_score):
            pass


if __name__ == "__main__":
    main()