-
Notifications
You must be signed in to change notification settings - Fork 4
/
reset_experiment.py
160 lines (133 loc) · 5.5 KB
/
reset_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Main script to evaluate xAI algorithms. It generates the raw scores for every test and every xAI."""
import datetime
import inspect
import logging
import sys
import time
import traceback
from typing import Type
from explainers.explainer_superclass import Explainer, UnsupportedModelException
from src.scoring import get_details
from src.explainer import valid_explainers
from src.io import *
from src.test import valid_tests
from src.utils import TimeoutException
from tests.test_superclass import Test
# Configure root logging once at import time: timestamped, level-tagged messages.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%H:%M:%S",
)
# NOTE(review): these imports sit after the logging setup; kept in place to
# preserve the original import order and module side effects.
from contextlib import contextmanager
import threading
import _thread
def not_string(e):
    """Return *e* unchanged, or None when *e* is a string.

    Used to drop placeholder/error strings an explainer may have stored
    where an explanation object was expected.
    """
    return None if isinstance(e, str) else e
@contextmanager
def time_limit(seconds, msg=''):
    """Limit the wrapped block to *seconds* of wall-clock time.

    A background Timer calls _thread.interrupt_main() on expiry, which is
    delivered to the main thread as KeyboardInterrupt; that interrupt is
    re-raised here as TimeoutException.  Must be used from the main thread.
    """
    watchdog = threading.Timer(seconds, lambda: _thread.interrupt_main())
    watchdog.start()
    try:
        yield
    except KeyboardInterrupt:
        raise TimeoutException("Timed out for operation {}".format(msg))
    finally:
        # If the block finished in time, disarm the watchdog (no-op if it fired).
        watchdog.cancel()
def get_empty_result(*args):
    """Return a fresh empty dict (one placeholder result cell).

    Any positional arguments are accepted for call-site compatibility
    (e.g. DataFrame.applymap passes the cell value) and ignored.
    """
    return {}
def empty_results(_iterable):
    """Return one empty result cell per element of *_iterable*."""
    cells = []
    for _ in _iterable:
        cells.append(get_empty_result())
    return cells
def append_empty_row(result_df, name):
    """Return a copy of *result_df* with one extra row labelled *name*.

    Every cell of the new row is an independent empty result dict; the
    original DataFrame is not modified (pd.concat builds a new frame).
    """
    new_row = pd.DataFrame(
        [empty_results(result_df.columns)],
        index=[name],
        columns=result_df.columns,
    )
    return pd.concat([result_df, new_row])
def compatible(test_class, explainer_class):
    """Return True when the explainer can produce an explanation the test scores.

    A test declares which explanation kinds it consumes via the parameter
    names of its ``score`` method; an explainer advertises what it produces
    via class attributes ``output_importance`` / ``output_attribution`` /
    ``output_interaction``.  The pair is compatible if at least one kind is
    both required by the test and produced by the explainer.

    Bug fix: the original ``return False`` lived in the loop's ``else``
    branch, so only the first kind ('importance') was ever inspected and
    'attribution'/'interaction' were unreachable.  The function now checks
    all three kinds before giving up.
    """
    # Invariant: the test's score signature doesn't change per kind — hoist it.
    score_args = inspect.getfullargspec(test_class.score).args
    for explanation in ('importance', 'attribution', 'interaction'):
        required_by_test = explanation in score_args
        provided_by_explainer = explainer_class.__dict__.get(f'output_{explanation}', False)
        if required_by_test and provided_by_explainer:
            return True
    return False
TIME_LIMIT = 3*60*60  # seconds # todo [before submission] increase # src https://stackoverflow.com/questions/366682/how-to-limit-execution-time-of-a-function-call
def format_results(score=None, time=None):  # todo add note for unsopported Model exception
    """Build one result-cell dict from an optional score and elapsed time.

    Keys 'score' and 'time' are included only when the corresponding
    argument is not None; 'Last_updated' always records the current
    local timestamp as a string.

    Bug fix: the original line ``results['Last_updated']: str(...)`` was an
    annotation expression, not an assignment, so the timestamp was silently
    never stored.  It is now a real assignment.
    """
    results = {}
    if score is not None:
        results['score'] = score
    if time is not None:
        results['time'] = time
    results['Last_updated'] = str(datetime.datetime.now())
    return results
def run_experiment(test_class: Type[Test], explainer_class: Type[Explainer]):
    """Run one (test, explainer) cell: init, explain under a time limit, score.

    Returns a result dict from format_results(): empty-ish on incompatibility
    or failure, otherwise containing the test score and the explain() wall time.
    """
    print(test_class.__name__, explainer_class.__name__)
    if not compatible(test_class, explainer_class):
        print('not compatible')
        return format_results()
    # Init test
    test = test_class()
    start_time = time.time()
    # Init Explainer
    try:
        # Forward every non-dunder attribute of the test instance as a keyword
        # argument; the explainer picks what it needs from this grab-bag.
        arg = dict(**{key: getattr(test, key) for key in dir(test) if key[:2] != '__'})
        _explainer = explainer_class(**arg)
    except UnsupportedModelException:
        # Explainer declares it cannot handle this model: record an empty cell.
        print('UnsupportedModelException')
        return format_results()
    except Exception as e:
        # Any other constructor failure: log the traceback, record an empty cell.
        exc_info = sys.exc_info()
        traceback.print_exception(*exc_info)
        return format_results()
    # Explain
    try:
        with time_limit(TIME_LIMIT, 'explain'):
            try:
                _explainer.explain(dataset_to_explain=test.dataset_to_explain, truth_to_explain=test.truth_to_explain)
                execution_time = time.time() - start_time
            except TimeoutException:
                # Let the outer handler record the timeout.
                raise
            except Exception as e:
                exc_info = sys.exc_info()
                traceback.print_exception(*exc_info)
                print('Err while explaining')
                # Brief pause so the traceback flushes before the next print.
                time.sleep(.1)
                return format_results()
        _explainer.check_explanation(test.dataset_to_explain)
    except TimeoutException as e:
        # Timed-out runs are still scored below, with time capped at the limit.
        print("Timed out!")
        execution_time = TIME_LIMIT
    # Score the output
    # not_string() drops placeholder strings stored where explanations belong,
    # so test.score() receives either real explanation objects or None.
    arg = {key: not_string(_explainer.__dict__.get(key)) for key in ['attribution', 'importance', 'interaction']}
    score = test.score(**arg)
    return format_results(score=score, time=execution_time)
if __name__ == "__main__":
print(f'Explainers: {len(valid_explainers)}')
print(f'Tests: {len(valid_tests)}')
result_df = load_results()
if result_df is None: # todo [after acceptance] move to io.py
result_df = pd.DataFrame(index=[e.name for e in valid_explainers],
columns=[t.name for t in valid_tests]).applymap(get_empty_result)
try:
for explainer_class in valid_explainers:
if explainer_class.name not in result_df.index:
result_df = append_empty_row(result_df, explainer_class.name)
for test_class in valid_tests:
if test_class.name not in result_df.columns: # todo [after acceptance] check if this line is really important
result_df[test_class.name] = empty_results(result_df.index)
result = run_experiment(test_class, explainer_class)
print('old result', result_df.loc[explainer_class.name, test_class.name])
result_df.at[explainer_class.name, test_class.name] = result
print('new result', result_df.loc[explainer_class.name, test_class.name])
except KeyboardInterrupt:
pass
print(result_df)
save_results_safe(result_df)
summary_df, eligible_points_df, score_df = get_details(result_df)
print('Now run src/aggregate_data.py')