-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathkelpie.py
147 lines (126 loc) · 8.66 KB
/
kelpie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from typing import Tuple, Any
from dataset import Dataset
from prefilters.no_prefilter import NoPreFilter
from prefilters.prefilter import TYPE_PREFILTER, TOPOLOGY_PREFILTER, NO_PREFILTER
from prefilters.type_based_prefilter import TypeBasedPreFilter
from prefilters.topology_prefilter import TopologyPreFilter
from relevance_engines.post_training_engine import PostTrainingEngine
from link_prediction.models.model import Model
from explanation_builders.stochastic_necessary_builder import StochasticNecessaryExplanationBuilder
from explanation_builders.stochastic_sufficient_builder import StochasticSufficientExplanationBuilder
class Kelpie:
"""
The Kelpie object is the overall manager of the explanation process.
It implements the whole explanation pipeline, requesting the suitable operations
to the Pre-Filter, Explanation Builder and Relevance Engine modules.
"""
DEFAULT_MAX_LENGTH = 4
def __init__(self,
model: Model,
dataset: Dataset,
hyperparameters: dict,
prefilter_type: str,
relevance_threshold: float = None,
max_explanation_length: int = DEFAULT_MAX_LENGTH):
"""
Kelpie object constructor.
:param model: the model to explain
:param dataset: the dataset used to train the model
:param hyperparameters: the hyperparameters of the model and of its optimization process
:param prefilter_type: the type of prefilter to employ
:param relevance_threshold: the threshold of relevance that, if exceeded, leads to explanation acceptance
:param max_explanation_length: the maximum number of facts that the explanations to extract can contain
"""
self.model = model
self.dataset = dataset
self.hyperparameters = hyperparameters
self.relevance_threshold = relevance_threshold
self.max_explanation_length = max_explanation_length
if prefilter_type == TOPOLOGY_PREFILTER:
self.prefilter = TopologyPreFilter(model=model, dataset=dataset)
elif prefilter_type == TYPE_PREFILTER:
self.prefilter = TypeBasedPreFilter(model=model, dataset=dataset)
elif prefilter_type == NO_PREFILTER:
self.prefilter = NoPreFilter(model=model, dataset=dataset)
else:
self.prefilter = TopologyPreFilter(model=model, dataset=dataset)
self.engine = PostTrainingEngine(model=model,
dataset=dataset,
hyperparameters=hyperparameters)
def explain_sufficient(self,
sample_to_explain: Tuple[Any, Any, Any],
perspective: str,
num_promising_samples: int = 50,
num_entities_to_convert: int = 10,
entities_to_convert: list = None):
"""
This method extracts sufficient explanations for a specific sample,
from the perspective of either its head or its tail.
:param sample_to_explain: the sample to explain
:param perspective: a string conveying the perspective of the requested explanations.
It can be either "head" or "tail":
- if "head", Kelpie answers the question
"given the sample head and relation, why is the sample tail predicted as tail?"
- if "tail", Kelpie answers the question
"given the sample relation and tail, why is the sample head predicted as head?"
:param num_promising_samples: the number of samples relevant to the sample to explain
that must be identified and added to the extracted similar entities
to verify whether they boost the target prediction or not
:param num_entities_to_convert: the number of entities to convert to extract
(if they have to be extracted)
:param entities_to_convert: the entities to convert
(if they are passed instead of having to be extracted)
:return: two lists:
the first one contains, for each relevant n-ple extracted, a couple containing
- that relevant sample
- its value of global relevance across the entities to convert
the second one contains the list of entities that the extractor has tried to convert
in the sufficient explanation process
"""
most_promising_samples = self.prefilter.top_promising_samples_for(sample_to_explain=sample_to_explain,
perspective=perspective,
top_k=num_promising_samples)
explanation_builder = StochasticSufficientExplanationBuilder(model=self.model,
dataset=self.dataset,
hyperparameters=self.hyperparameters,
sample_to_explain=sample_to_explain,
perspective=perspective,
num_entities_to_convert=num_entities_to_convert,
entities_to_convert=entities_to_convert,
relevance_threshold=self.relevance_threshold,
max_explanation_length=self.max_explanation_length)
explanations_with_relevance = explanation_builder.build_explanations(samples_to_add=most_promising_samples)
return explanations_with_relevance, explanation_builder.entities_to_convert
def explain_necessary(self,
sample_to_explain: Tuple[Any, Any, Any],
perspective: str,
num_promising_samples: int = 50):
"""
This method extracts necessary explanations for a specific sample,
from the perspective of either its head or its tail.
:param sample_to_explain: the sample to explain
:param perspective: a string conveying the perspective of the requested explanations.
It can be either "head" or "tail":
- if "head", Kelpie answers the question
"given the sample head and relation, why is the sample tail predicted as tail?"
- if "tail", Kelpie answers the question
"given the sample relation and tail, why is the sample head predicted as head?"
:param num_promising_samples: the number of samples relevant to the sample to explain
that must be identified and removed from the entity under analysis
to verify whether they worsen the target prediction or not
:return: a list containing for each relevant n-ple extracted, a couple containing
- that relevant n-ple
- its value of relevance
"""
most_promising_samples = self.prefilter.top_promising_samples_for(sample_to_explain=sample_to_explain,
perspective=perspective,
top_k=num_promising_samples)
explanation_builder = StochasticNecessaryExplanationBuilder(model=self.model,
dataset=self.dataset,
hyperparameters=self.hyperparameters,
sample_to_explain=sample_to_explain,
perspective=perspective,
relevance_threshold=self.relevance_threshold,
max_explanation_length=self.max_explanation_length)
explanations_with_relevance = explanation_builder.build_explanations(samples_to_remove=most_promising_samples)
return explanations_with_relevance