Merge pull request #186 from dianna-ai/hyperparameter_configs
Added boilerplate for experiments and appropriate hyperparameter configs
WillemSpek authored Jun 19, 2023
2 parents 17ac312 + 17bbd75 commit 16f4ab4
Showing 8 changed files with 321 additions and 0 deletions.
Empty file.
Empty file.
5 changes: 5 additions & 0 deletions relevance_maps_properties/experiments/distance_metrics.py
@@ -0,0 +1,5 @@
import numpy as np

from scipy.stats import wasserstein_distance
from numpy.typing import NDArray
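
This module only declares its imports in this commit. As a hypothetical sketch (not part of the commit), `wasserstein_distance` could be used to compare two relevance maps by treating their flattened values as 1-D distributions:

def relevance_map_distance(map_a: NDArray, map_b: NDArray) -> float:
    '''Hypothetical helper: Wasserstein distance between the value
    distributions of two relevance maps.'''
    return wasserstein_distance(np.ravel(map_a), np.ravel(map_b))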

116 changes: 116 additions & 0 deletions relevance_maps_properties/experiments/hyperparameter_configs.py
@@ -0,0 +1,116 @@
import numpy as np

from typing import Optional, Iterable
from skimage.segmentation import slic
from sklearn.model_selection import ParameterGrid


def create_grid(parameters: object) -> list:
''' Convert parameter objects to a grid containing all possible parameter
combinations.
Args:
parameters: Parameters to use in the grid
Returns: All possible parameter combinations
'''
return list(ParameterGrid(parameters.__dict__))


class RISE_parameters(object):
'''Set up hyperparameters for RISE.
'''
def __init__(self,
p_keep: Optional[Iterable] = None,
feature_res: Optional[Iterable] = None,
n_masks: Optional[Iterable] = None,
random_state: Optional[Iterable[int]] = None):
'''
Args:
p_keep: probability to keep bit unmasked
feature_res: size of bitmask
n_masks: number of masks to use
random_state: random seed
'''
        self.p_keep = p_keep
self.feature_res = feature_res
self.n_masks = n_masks
self.random_state = random_state


class LIME_parameters(object):
    '''Set up hyperparameters for LIME.
    NOTE: LIME segments the image using quickshift, which is statically implemented in
    their package. We should discuss whether we want to make this segmentation modifiable as a
    hyperparameter by changing the LIME implementation and trying out a different segmentation algorithm.
'''
def __init__(self,
num_samples: Optional[Iterable] = None,
kernel_width: Optional[Iterable] = None,
feature_selection: Optional[Iterable] = None,
distance_metric: Optional[Iterable] = None,
segmentation_fn: Optional[Iterable] = None,
model_regressor: Optional[Iterable] = None,
random_state: Optional[Iterable] = None):
'''
Args:
num_samples: amount of instances to perturb
kernel_width: width to use for kernel to compute proximity
feature_selection: feature selection algorithm to select a priori
distance_metric: distance metric used to compute proximity
segmentation_fn: Segmentation algorithm to obtain superpixels
model_regressor: Surrogate model to use
random_state: random seed
'''
self.num_samples = num_samples
self.kernel_width = kernel_width
self.feature_selection = feature_selection
self.distance_metric = distance_metric
self.segmentation_fn = segmentation_fn
self.model_regressor = model_regressor
self.random_state = random_state


class SHAP_parameters(object):
''' Set up hyperparameters for KernelSHAP.'''
def __init__(self,
nsamples: Optional[Iterable] = None,
background: Optional[Iterable]= None,
sigma: Optional[Iterable] = None,
l1_reg: Optional[Iterable] = None,
random_state: Optional[Iterable] = None):
'''
Args:
nsamples: amount of combinations to use
background: background of masked image
sigma: gaussian kernel width
l1_reg: L1 regularization factor
random_state: random seed
'''
        self.nsamples = nsamples
self.background = background
self.sigma = sigma
self.l1_reg = l1_reg
self.random_state = random_state


RISE_config = RISE_parameters(
p_keep = np.arange(.1, 1, .1),
feature_res=np.arange(1, 10, 2),
n_masks=np.arange(1000, 4000, 500)
)


LIME_config = LIME_parameters(
num_samples=np.arange(1000, 4000, 500),
kernel_width=np.geomspace(0.01, 3, num=5),
distance_metric=[None], # will extend later
    segmentation_fn=[slic],  # ParameterGrid expects an iterable of candidate values
    random_state=[42]
)


SHAP_config = SHAP_parameters(
nsamples=np.arange(1000, 4000, 500),
l1_reg=np.geomspace(.001, 1, num=5)
)
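
A brief usage sketch (not part of the commit): `create_grid` expands a parameter object into all combinations via `ParameterGrid`. Every attribute is assumed to be an iterable of candidate values, so unset (None) fields, such as `random_state` in `RISE_config`, would have to be dropped or filled in before expansion. The values below are illustrative only:

example = RISE_parameters(p_keep=[0.1, 0.5],
                          feature_res=[6],
                          n_masks=[1000],
                          random_state=[42])
for params in create_grid(example):
    print(params)  # e.g. {'feature_res': 6, 'n_masks': 1000, 'p_keep': 0.1, 'random_state': 42}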
189 changes: 189 additions & 0 deletions relevance_maps_properties/experiments/run_experiments.py
@@ -0,0 +1,189 @@
import argparse
import dianna
import quantus
import json

import numpy as np

from dianna.utils.onnx_runner import SimpleModelRunner
from multiprocessing import Process
from numpy.typing import NDArray
from onnx import load
from onnx2keras import onnx_to_keras
from onnx.onnx_ml_pb2 import ModelProto
from pathlib import Path
from tqdm import tqdm
from time import time_ns
from typing import Callable, Union, Optional

# Local imports
from .hyperparameter_configs import LIME_parameters, RISE_parameters, SHAP_parameters, create_grid
from ..metrics.metrics import Incremental_deletion
from ..metrics import utils


class Experiments(object):
    '''Class for the hyperparameter experiments.
    All the necessary functionality with regard to the experiments is implemented
    here.
    NOTE: This class uses JSON to store data; however, with the
    data possibly scaling up to large sizes, we should look into a MongoDB backend
    or HDF5 file storage.
'''
def __init__(self,
model: Union[ModelProto, str],
n_samples: int = 5,
preprocess_function: Optional[Callable] = None,
evaluator_kwargs: Optional[dict] = None,
model_kwargs: Optional[dict] = None,
**kwargs):

# Model preprocessing for cross-framework evaluation
if isinstance(model, str):
model = load(model)
if isinstance(model, ModelProto):
self.model = dianna.utils.get_function(model, preprocess_function=preprocess_function)
            input_names, _ = utils.get_onnx_names(model)  # inspect the raw ONNX proto, not the wrapped runner
            self.keras_model = onnx_to_keras(model, input_names)
        else:
            raise TypeError('`model` must be an ONNX `ModelProto` or a path to an ONNX file.')

        self.n_samples = n_samples
        self.batch_size = kwargs.get('batch_size', 64)  # assumed default; used by the metrics below
        evaluator_kwargs = evaluator_kwargs or {}
        model_kwargs = model_kwargs or {}
        id_kwargs = dianna.utils.get_kwargs_applicable_to_function(Incremental_deletion.__init__, evaluator_kwargs)
        quantus_kwargs = dianna.utils.get_kwargs_applicable_to_function(quantus.AvgSensitivity.__init__, evaluator_kwargs)

self.incr_del = Incremental_deletion(self.model, **id_kwargs, **model_kwargs)
self.avg_sensitivity = quantus.AvgSensitivity(nr_samples=self.n_samples,
**quantus_kwargs)
self.max_sensitivity = quantus.MaxSensitivity(nr_samples=self.n_samples,
**quantus_kwargs)

    def init_JSON_format(self, experiment_name: str, n_images: int, n_configs: int) -> dict:
        ''' Return the hierarchical structure and metadata for the experiments data.
        Returns the data format that `explain_evaluate_images` expects to dump the
        results in. Currently JSON seems a reasonable choice.
        '''
        output = {'experiment_name': experiment_name,
                  'images': [
                      {
                          'image_id': image_id,
                          'image_data': [],
                          'configs': [
                              {
                                  'config_id': config_id,
                                  'config': [],
                                  'salient_batch': [],
                                  'incremental_deletion': {},
                                  'avg_sensitivity': 0.,
                                  'max_sensitivity': 0.,
                                  'run_time': 0.,
                              }
                              # Comprehension instead of list multiplication so every
                              # config entry is an independent dict.
                              for config_id in range(n_configs)
                          ],
                      }
                      for image_id in range(n_images)
                  ]}
        return output

def explain_evaluate_images(self,
output_file: Path,
data: NDArray,
method: str,
grid: list[dict],
n_samples: int = 5,
model_kwargs: Optional[dict] = None,
) -> None:
''' This function will run our explainers and evaluators.
'''
if output_file.suffix != '.json':
raise ValueError('`output_file` must end with `.json`.')

explainer = self._get_explain_func(method)
        results = self.init_JSON_format(output_file.stem, data.shape[0], len(grid))  # file stem as experiment name

        model_kwargs = model_kwargs or {}
        for image_id, image_data in enumerate(tqdm(data, desc='Running Experiments')):
            results['images'][image_id]['image_data'] = image_data.tolist()
            for config_id, explainer_params in enumerate(grid):
                # Keep a JSON-serialisable record of this hyperparameter configuration.
                results['images'][image_id]['configs'][config_id]['config'] = {
                    key: repr(value) for key, value in explainer_params.items()}
                salient_batch = np.empty((n_samples, *image_data.shape[:2]))

                start_time = time_ns()
                for i in range(n_samples):
                    salient_batch[i] = explainer(image_data, **explainer_params)
                run_time = (time_ns() - start_time) / self.n_samples  # mean explanation time per sample

# Compute metrics
                y_batch = np.asarray([self.model(image_data, **model_kwargs).argmax()])
                incr_del = self.incr_del(image_data,
                                         salient_batch,
                                         batch_size=self.batch_size,
                                         **model_kwargs)
                incr_del.pop('salient_batch', None)  # keep only the metric scores in the results
                avg_sensitivity = self.avg_sensitivity(model=self.keras_model,
x_batch=salient_batch,
y_batch=y_batch,
batch_size=self.batch_size)
max_sensitivity = self.max_sensitivity(model=self.keras_model,
x_batch=image_data,
y_batch=y_batch,
batch_size=self.batch_size)

# Save results
results['images'][image_id]['configs'][config_id]['incremental_deletion'] = incr_del
                results['images'][image_id]['configs'][config_id]['avg_sensitivity'] = avg_sensitivity
                results['images'][image_id]['configs'][config_id]['max_sensitivity'] = max_sensitivity
                results['images'][image_id]['configs'][config_id]['run_time'] = run_time

# Write results to file
with open(output_file, 'w') as f_out:
            json.dump(results, f_out)

    @staticmethod
    def _get_explain_func(method: str) -> Callable:
if not isinstance(method, str):
raise TypeError('Please provide `method` as type str')

        if method.upper() == 'KERNELSHAP':
            return utils.SHAP_postprocess
        elif method.upper() == 'LIME':
            return utils.LIME_postprocess
        elif method.upper() == 'RISE':
return dianna.explain_image
else:
raise ValueError('''Given method is not supported, please choose between
KernelShap, RISE and LIME.''')


def pool_handler():
'''Extend support for distributed computing
This function should generate several processes such
that our code can be run in a distributed manner.
'''
raise NotImplementedError()


def main():
parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--data', type=str, required=True)  # path to the image data
parser.add_argument('--method', type=str, required=True)
parser.add_argument('--step', type=int, required=True)
parser.add_argument('--out', type=str, required=True)
parser.add_argument('--batch_size', type=int, required=True)
parser.add_argument('--n_samples', type=int, default=5)

args = parser.parse_args()
kwargs = vars(args)

# TODO: make grid
# TODO: load in dataset

experiments = Experiments(kwargs.pop('model'), **kwargs)

proc = Process(target=experiments.explain_evaluate_images)
    # TODO: pass output_file, data, method and grid to the target once the
    # dataset loading and grid construction above are implemented.
    proc.start()
    proc.join()

if __name__ == '__main__':
main()
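
`pool_handler` above is still a stub. A minimal sketch of one way it could fan configuration sweeps out over worker processes (the function names, arguments and defaults below are assumptions, not part of this commit):

from functools import partial
from multiprocessing import Pool


def _run_job(model_path: str, job: dict) -> None:
    # Build the model inside the worker (ONNX/Keras handles may not pickle)
    # and run one configuration sweep; `job` holds the keyword arguments for
    # Experiments.explain_evaluate_images.
    experiments = Experiments(model_path)
    experiments.explain_evaluate_images(**job)


def pool_handler(model_path: str, jobs: list[dict], n_workers: int = 4) -> None:
    with Pool(processes=n_workers) as pool:
        pool.map(partial(_run_job, model_path), jobs)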
11 changes: 11 additions & 0 deletions relevance_maps_properties/experiments/run_experiments.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
#SBATCH --time=48:00:00
#SBATCH --gres=gpu:1
#SBATCH -C A4000

source ~/.bashrc
module load cuda11.2/toolkit
mamba activate embeddings
cd ~/scratch/explainable_embedding/
# Must run the python file as a module for the local imports to work.
python3 -m relevance_maps_properties.experiments.run_experiments
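
The sbatch script above invokes the module without arguments, while run_experiments.py marks several flags as required. A hypothetical full invocation (all paths and values are placeholders) might look like:

python3 -m relevance_maps_properties.experiments.run_experiments \
    --model models/mnist_model.onnx \
    --data data/validation_images.npy \
    --method RISE \
    --step 1 \
    --out results/rise_grid_search.json \
    --batch_size 64 \
    --n_samples 5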
Empty file.
Empty file.
