Skip to content

Commit

Permalink
Added generate-minimize strategy to meta-explainers
Browse files Browse the repository at this point in the history
  • Loading branch information
MarioTheOne committed Sep 26, 2024
1 parent b57509b commit 442728f
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 238 deletions.
357 changes: 124 additions & 233 deletions lab/1-evaluation_pipeline.ipynb

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions lab/config/meta/generate_minimize.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{
"experiment" : {
"scope": "meta-overshoot-reduce",
"parameters" : {
"lock_release_tout":120,
"propagate":[
{"in_sections" : ["doe-triplets/explainer"],"params" : {"fold_id": 1}},
{"in_sections" : ["doe-triplets/oracle"],"params" : {"fold_id": -1,"retrain":true}},
{"in_sections": ["doe-triplets/dataset"],"params": { "compose_man" : "lab/config/snippets/datasets/centr_and_weights.json" }}
],
"expand" : { "folds" : ["doe-triplets/explainer"], "triplets" : true }
}
},
"doe-triplets":[
{
"compose_do": "./lab/config/snippets/do-pairs/TCR-128-28-0.25_TCO.json",
"explainer": {
"class": "src.explainer.future.meta.generate_minimize.GenerateMinimizeExplainer",
"parameters":{
"fold_id": 1,
"generator": {
"class": "src.explainer.future.generative.rsgg.RSGG",
"parameters":{
"epochs": 500
}
},
"minimizer": {
"class": "src.explainer.future.meta.minimizer.random.RandomMinimizer",
"parameters": {}
}
}
}
}
],
"evaluator": {
"class": "src.evaluation.future.evaluator.Evaluator",
"parameters": {
"compose_pip": "./lab/config/snippets/default_pipeline.json"
}
},
"compose_strs" : "./lab/config/snippets/default_store_paths.json"
}
32 changes: 32 additions & 0 deletions lab/config/meta/rsgg.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"experiment" : {
"scope": "meta-overshoot-reduce",
"parameters" : {
"lock_release_tout":120,
"propagate":[
{"in_sections" : ["doe-triplets/explainer"],"params" : {"fold_id": 1}},
{"in_sections" : ["doe-triplets/oracle"],"params" : {"fold_id": -1,"retrain":true}},
{"in_sections": ["doe-triplets/dataset"],"params": { "compose_man" : "lab/config/snippets/datasets/centr_and_weights.json" }}
],
"expand" : { "folds" : ["doe-triplets/explainer"], "triplets" : true }
}
},
"doe-triplets":[
{
"compose_do": "./lab/config/snippets/do-pairs/TCR-128-28-0.25_TCO.json",
"explainer": {
"class": "src.explainer.future.generative.rsgg.RSGG",
"parameters":{
"epochs": 500
}
}
}
],
"evaluator": {
"class": "src.evaluation.future.evaluator.Evaluator",
"parameters": {
"compose_pip": "./lab/config/snippets/default_pipeline.json"
}
},
"compose_strs" : "./lab/config/snippets/default_store_paths.json"
}
42 changes: 42 additions & 0 deletions src/explainer/future/meta/generate_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from src.core.factory_base import get_class, get_instance_kvargs
from src.core.trainable_base import Trainable
from src.utils.cfg_utils import inject_dataset, inject_oracle
import src.utils.explanations.functions as exp_tools
from src.future.explanation.local.graph_counterfactual import LocalGraphCounterfactualExplanation


class GenerateMinimizeExplainer(Explainer, Trainable):
Expand All @@ -11,6 +13,7 @@ class GenerateMinimizeExplainer(Explainer, Trainable):

def check_configuration(self):
super().check_configuration()
self.logger = self.context.logger

if 'generator' not in self.local_config['parameters']:
raise Exception('A generate-minimize method requires a generator')
Expand All @@ -21,10 +24,14 @@ def check_configuration(self):
# Inject the oracle and the dataset into the generator explainer
inject_dataset(self.local_config['parameters']['generator'], self.dataset)
inject_oracle(self.local_config['parameters']['generator'], self.oracle)
# all the components should use the same fold
self.local_config['parameters']['generator']['parameters']['fold_id'] = self.local_config['parameters']['fold_id']

# Inject the oracle and the dataset into the minimizer
inject_dataset(self.local_config['parameters']['minimizer'], self.dataset)
inject_oracle(self.local_config['parameters']['minimizer'], self.oracle)
# all the components should use the same fold
self.local_config['parameters']['minimizer']['parameters']['fold_id'] = self.local_config['parameters']['fold_id']


def init(self):
Expand All @@ -37,5 +44,40 @@ def init(self):
{'context':self.context,'local_config': self.local_config['parameters']['minimizer']})


def real_fit(self):
# Chose if the components will be trained at creation time or in this fit
pass


def explain(self, instance):

# Using the generator to obtain an initial explanation
initial_explanation = self.explanation_generator.explain(instance)
initial_cf = initial_explanation.counterfactual_instances[0]

# Getting the predicted label of the initial explanation
initial_cf_label = self.oracle.predict(initial_cf)

if initial_cf == instance.label:
# the generator was not able to produce a counterfactual
# so we can inmediately return, there is no point in minimizing
return initial_explanation

# Try to minimize the distance between the counterfactual example and the original instance
minimum_cf = self.explanation_minimizer.minimize(instance, initial_cf)

minimal_explanation = LocalGraphCounterfactualExplanation(context=self.context,
dataset=self.dataset,
oracle=self.oracle,
explainer=self,
input_instance=instance,
counterfactual_instances=[minimum_cf])

return minimal_explanation


def write(self):
pass

def read(self):
pass
2 changes: 1 addition & 1 deletion src/explainer/future/meta/minimizer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExplanationMinimizer(Trainable):

def check_configuration(self):
super().check_configuration()
pass
self.logger = self.context.logger


def init(self):
Expand Down
23 changes: 19 additions & 4 deletions src/explainer/future/meta/minimizer/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
from src.explainer.future.meta.minimizer.base import ExplanationMinimizer
from src.dataset.instances.base import DataInstance
from src.dataset.instances.graph import GraphInstance

from src.utils.comparison import get_all_edge_differences, get_edge_differences

class RandomMinimizer(ExplanationMinimizer):

def minimize(self, instance, cf_instance) -> DataInstance:
changed_edges, _, _ = self.get_all_edge_differences(instance, [cf_instance])
# Get the changes between the original graph and the initial counterfactual
changed_edges, _, _ = get_all_edge_differences(instance, [cf_instance])

# apply the backward search to minimize the counterfactual
minimal_cf = self.oblivious_backward_search(instance, cf_instance, changed_edges)

# Return the minimal counterfactual
return minimal_cf


def oblivious_backward_search(self, instance, cf_instance, changed_edges, k=5, maximum_oracle_calls=2000):
'''
Expand Down Expand Up @@ -58,6 +66,13 @@ def oblivious_backward_search(self, instance, cf_instance, changed_edges, k=5, m

result_cf = GraphInstance(id=instance.id, label=0, data=gc, directed=instance.directed)
self.dataset.manipulate(result_cf)
result_cf.label = self.oracle.predict(reduced_cf_inst)
result_cf.label = self.oracle.predict(result_cf)

return result_cf


def write(self):
pass

return result_cf
def read(self):
pass

0 comments on commit 442728f

Please sign in to comment.