From ef75c0f7adfdacd77c49eab6f5c2ee16eb67e34a Mon Sep 17 00:00:00 2001 From: WillemSpek Date: Mon, 19 Jun 2023 11:36:01 +0200 Subject: [PATCH 1/4] Added documentation for run_experiments with minor changes --- .../experiments/__init__.py | 0 .../experiments/distance_metrics.py | 5 + .../experiments/hyperparameter_configs.py | 116 +++++++++ .../experiments/run_experiments.py | 233 ++++++++++++++++++ .../experiments/run_experiments.sh | 11 + 5 files changed, 365 insertions(+) create mode 100644 relevance_maps_properties/experiments/__init__.py create mode 100644 relevance_maps_properties/experiments/distance_metrics.py create mode 100644 relevance_maps_properties/experiments/hyperparameter_configs.py create mode 100644 relevance_maps_properties/experiments/run_experiments.py create mode 100644 relevance_maps_properties/experiments/run_experiments.sh diff --git a/relevance_maps_properties/experiments/__init__.py b/relevance_maps_properties/experiments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/relevance_maps_properties/experiments/distance_metrics.py b/relevance_maps_properties/experiments/distance_metrics.py new file mode 100644 index 0000000..7d629b9 --- /dev/null +++ b/relevance_maps_properties/experiments/distance_metrics.py @@ -0,0 +1,5 @@ +import numpy as np + +from scipy.stats import wasserstein_distance +from numpy.typing import NDArray + diff --git a/relevance_maps_properties/experiments/hyperparameter_configs.py b/relevance_maps_properties/experiments/hyperparameter_configs.py new file mode 100644 index 0000000..c0f6518 --- /dev/null +++ b/relevance_maps_properties/experiments/hyperparameter_configs.py @@ -0,0 +1,116 @@ +import numpy as np + +from typing import Optional, Iterable +from skimage.segmentation import slic +from sklearn.model_selection import ParameterGrid + + +def create_grid(parameters: object) -> list: + ''' Convert parameter objects to a grid containing all possible parameter + combinations. + + Args: + parameters: Parameters to use in the grid + + Returns: All possible parameter combinations + ''' + return list(ParameterGrid(parameters.__dict__)) + + +class RISE_parameters(object): + '''Set up hyperparameters for RISE. + ''' + def __init__(self, + p_keep: Optional[Iterable] = None, + feature_res: Optional[Iterable] = None, + n_masks: Optional[Iterable] = None, + random_state: Optional[Iterable[int]] = None): + ''' + Args: + p_keep: probability to keep bit unmasked + feature_res: size of bitmask + n_masks: number of masks to use + random_state: random seed + ''' + self.p_keep=p_keep + self.feature_res = feature_res + self.n_masks = n_masks + self.random_state = random_state + + +class LIME_parameters(object): + '''Set up hyperparamters for LIME. + NOTE: LIME segments the image using quickshift which is statically impleneted in + their package. We should discuss if we want to make this segmentation modifiable as a + hyperparameter by chanigng the LIME implementation and trying out a different segmentation algo. + ''' + def __init__(self, + num_samples: Optional[Iterable] = None, + kernel_width: Optional[Iterable] = None, + feature_selection: Optional[Iterable] = None, + distance_metric: Optional[Iterable] = None, + segmentation_fn: Optional[Iterable] = None, + model_regressor: Optional[Iterable] = None, + random_state: Optional[Iterable] = None): + ''' + Args: + num_samples: amount of instances to perturb + kernel_width: width to use for kernel to compute proximity + feature_selection: feature selection algorithm to select a priori + distance_metric: distance metric used to compute proximity + segmentation_fn: Segmentation algorithm to obtain superpixels + model_regressor: Surrogate model to use + random_state: random seed + ''' + self.num_samples = num_samples + self.kernel_width = kernel_width + self.feature_selection = feature_selection + self.distance_metric = distance_metric + self.segmentation_fn = segmentation_fn + self.model_regressor = model_regressor + self.random_state = random_state + + +class SHAP_parameters(object): + ''' Set up hyperparameters for KernelSHAP.''' + def __init__(self, + nsamples: Optional[Iterable] = None, + background: Optional[Iterable]= None, + sigma: Optional[Iterable] = None, + l1_reg: Optional[Iterable] = None, + random_state: Optional[Iterable] = None): + ''' + Args: + nsamples: amount of combinations to use + background: background of masked image + sigma: gaussian kernel width + l1_reg: L1 regularization factor + random_state: random seed + ''' + self.nsamples = nsamples, + self.background = background + self.sigma = sigma + self.l1_reg = l1_reg + self.random_state = random_state + + +RISE_config = RISE_parameters( + p_keep = np.arange(.1, 1, .1), + feature_res=np.arange(1, 10, 2), + n_masks=np.arange(1000, 4000, 500) +) + + +LIME_config = LIME_parameters( + num_samples=np.arange(1000, 4000, 500), + kernel_width=np.geomspace(0.01, 3, num=5), + distance_metric=[None], # will extend later + segmentation_fn=slic, + random_state = [42] +) + + +SHAP_config = SHAP_parameters( + nsamples=np.arange(1000, 4000, 500), + l1_reg=np.geomspace(.001, 1, num=5) +) \ No newline at end of file diff --git a/relevance_maps_properties/experiments/run_experiments.py b/relevance_maps_properties/experiments/run_experiments.py new file mode 100644 index 0000000..19ae802 --- /dev/null +++ b/relevance_maps_properties/experiments/run_experiments.py @@ -0,0 +1,233 @@ +import argparse +import dianna +import quantus +import json + +import numpy as np + +from dianna.utils.onnx_runner import SimpleModelRunner +from multiprocessing import Process +from numpy.typing import NDArray +from onnx import load +from onnx2keras import onnx_to_keras +from onnx.onnx_ml_pb2 import ModelProto +from pathlib import Path +from tqdm import tqdm +from time import time_ns +from typing import Callable, Union, Optional + +# Local imports +from .hyperparameter_configs import SHAP_config, LIME_config, RISE_config, create_grid +from ..metrics.metrics import Incremental_deletion +from ..metrics import utils + + +class Experiments(object): + '''Class for the hyperparamter experiments. + + All the necessary functionality with regards to the experiments is implemented + here. + + NOTE: This method utilizes JSON as a means to store data, however, with the + data possibily scaling up to large size, we should look into mongoDB backend + or HD5 file storage. + ''' + def __init__(self, + model: Union[ModelProto, str, Path], + n_samples: int = 5, + preprocess_function: Optional[Callable] = None, + evaluator_kwargs: Optional[dict] = None, + model_kwargs: Optional[dict] = None, + **kwargs): + ''' + Args: + model: the black-box model + n_samples: Number of samples to use for evaluation + preprocess_function: Preprocess function for the model + evaluator_kwargs: Kwargs for evaluation methods + model_kwargs, kwargs for the black-box model + Raises: + TypeError: In case model type mismatched with expected tpyes + ''' + # Model preprocessing for cross-framework evaluation + if isinstance(model, (str, Path)): + model = load(model) + if isinstance(model, ModelProto): + self.model = dianna.utils.get_function(model, preprocess_function=preprocess_function) + input_names, _ = utils.get_onnx_names(self.model) + self.keras_model = onnx_to_keras(self.model, input_names) + else: + raise TypeError('`model_or_function` failed to convert to Keras.') + + self.n_samples = n_samples + id_kwargs = dianna.utils.get_kwargs_applicable_to_function(Incremental_deletion.__init__, evaluator_kwargs) + quantus_kwargs = dianna.utils.get_kwargs_applicable_to_function(quantus.AvgSensitivity.__init__, evaluator_kwargs) + + self.incr_del = Incremental_deletion(self.model, **id_kwargs, **model_kwargs) + self.avg_sensitivity = quantus.AvgSensitivity(nr_samples=self.n_samples, + **quantus_kwargs) + self.max_sensitivity = quantus.MaxSensitivity(nr_samples=self.n_samples, + **quantus_kwargs) + + def init_JSON_format(experiment_name: str, n_images: int, n_configs: int) -> dict: + ''' Return the hierarchical structure and metadata for the experiments data. + + Returns the data format that `explain_evaluate_images` expects to dump the + results in. + + Args: + experiment_name: Name for the experiment + n_images: Number of images to run the experiment on + n_configs: Number of hyperparameter configurations + Returns: + Base dictionary representing JSON structure as output format. + ''' + output = {'experiment_name': experiment_name, + 'image': [ + { + 'image_id': 0, + 'imag_data': [], + 'configs': [ + { + 'config_id': 0, + 'config': [], + 'salient_batch': [], + 'incremental_deletion': {}, + 'avg_sensitivity': 0., + 'max_sensitivity': 0., + 'run_time': 0., + } + ] * n_configs + } + ] * n_images + } + return output + + def explain_evaluate_images(self, + output_file: Union[str, Path], + data: NDArray, + method: str, + grid: list[dict], + save_between: int = 100, + model_kwargs: Optional[dict] = None + ) -> None: + ''' This function will run our explainers and evaluators. + + Args: + output_file: File to write the results to. + data: The image data to experiment on + method: The explainer method to use + grid: The grid of possible hyperparameter configurations + save_between: Save results for every save_between images + model_kwargs: Kwargs to use for the model + + + ''' + if output_file.suffix != '.json': + raise ValueError('`output_file` must end with `.json`.') + + explainer = self._get_explain_func(method) + results = self.init_JSON_format(data.shape[0], len(grid)) + + for image_id, image_data in enumerate(tqdm(data, desc='Running Experiments')): + results['images'][image_id] + for config_id, explainer_params in enumerate(grid): + results['runs']['image_id'][image_id]['params_id'] = {} + salient_batch = np.empty((self.n_samples, *image_data.shape[:2])) + + start_time = time_ns() + for i in range(self.n_samples): + salient_batch[i] = explainer(image_data, **explainer_params) + end_time = (time_ns() - start_time) / self.n_samples + + # Compute metrics + y_batch = self.model(image_data, **model_kwargs).argmax()[np.newaxis, ...] + incr_del = self.incr_del(image_data, + salient_batch, + batch_size=self.batch_size, + **model_kwargs).pop('salient_batch') + avg_sensitiviy = self.avg_sensitivity(model=self.keras_model, + x_batch=salient_batch, + y_batch=y_batch, + batch_size=self.batch_size) + max_sensitivity = self.max_sensitivity(model=self.keras_model, + x_batch=image_data, + y_batch=y_batch, + batch_size=self.batch_size) + + # Save results + results['images'][image_id]['configs'][config_id]['incremental_deletion'] = incr_del + results['images'][image_id]['configs'][config_id]['avg_sensitivity'] = avg_sensitiviy + results['images'][image_id]['configs'][config_id]['max_sensitiviy'] = max_sensitivity + results['run_time'] = end_time - start_time + + # Write imbetween result to file in case of runtime failures + if image_id % save_between == 0: + print(f"Backing up at iteration {image_id}") + with open(output_file, 'w') as f_out: + json.dumps(results, f_out) + + # Save final results. + with open(output_file, 'w') as f_out: + json.dumps(results, f_out) + + def _get_explain_func(method: str) -> Callable: + '''Helper func to return appropriate explain function for method. + + Args: + method: Name of explanation method + Returns: + A function that contians the explanation method with post-processing + ''' + if not isinstance(method, str): + raise TypeError('Please provide `method` as type str') + + if method.to_upper() == 'KERNELSHAP': + return utils.SHAP_postprocess + elif method.to_upper() == 'LIME': + return utils.LIME_postprocess + elif method.to_upper() == 'RISE': + return dianna.explain_image + else: + raise ValueError('''Given method is not supported, please choose between + KernelShap, RISE and LIME.''') + + +def pool_handler(): + '''Extend support for distributed computing + + This function should generate several processes such + that our code can be run in a distributed manner. + ''' + raise NotImplementedError() + + +def main(): + ''' Main function to run the experiments. + + All experiments are called here, its is configurable through + command-line arguments. + ''' + parser = argparse.ArgumentParser() + + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--data', type=str, required=True) + parser.add_argument('--out', type=str, default='./') + parser.add_argument('--step', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=100) + parser.add_argument('--n_samples', type=int, default=5) + + args = parser.parse_args() + kwargs = vars(args) + + data = np.load(kwargs.pop('data')) + for method, grid in zip(['RISE', 'LIME', 'KernelSHAP'], + [RISE_config, LIME_config, SHAP_config]): + out = Path(kwargs.pop('out') / method / '.json') + experiments = Experiments(kwargs.pop('model') **kwargs) + proc = Process(target=experiments.explain_evaluate_images, + args=[out, data, method], kwargs=kwargs) + proc.start + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/relevance_maps_properties/experiments/run_experiments.sh b/relevance_maps_properties/experiments/run_experiments.sh new file mode 100644 index 0000000..bde4bde --- /dev/null +++ b/relevance_maps_properties/experiments/run_experiments.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +#SBATCH --time=48:00:00 +#SBATCH --gres=gpu:1 +#SBATCH -C A4000 + +source ~/.bashrc +module load cuda11.2/toolkit +mamba activate embeddings +cd ~/scratch/explainable_embedding/ +# Must run python file a module for the local imports to work. +python3 -m relevance_maps_properties.Experiments.run_experiments From 63c355c92a25374f5a845a334a7bd002b917e1f7 Mon Sep 17 00:00:00 2001 From: WillemSpek Date: Thu, 22 Jun 2023 15:06:03 +0200 Subject: [PATCH 2/4] Multiple bugfixes on run_experiments.p --- .../experiments/hyperparameter_configs.py | 11 +- .../experiments/run_experiments.py | 153 +++++++++++------- relevance_maps_properties/metrics/metrics.py | 6 +- relevance_maps_properties/metrics/utils.py | 14 +- requirements.in | 2 +- 5 files changed, 113 insertions(+), 73 deletions(-) diff --git a/relevance_maps_properties/experiments/hyperparameter_configs.py b/relevance_maps_properties/experiments/hyperparameter_configs.py index c0f6518..28e3e3a 100644 --- a/relevance_maps_properties/experiments/hyperparameter_configs.py +++ b/relevance_maps_properties/experiments/hyperparameter_configs.py @@ -13,8 +13,11 @@ def create_grid(parameters: object) -> list: parameters: Parameters to use in the grid Returns: All possible parameter combinations - ''' - return list(ParameterGrid(parameters.__dict__)) + ''' + params = parameters.__dict__ + params = {k: params[k] for k in params.keys() + if params[k] is not None} + return list(ParameterGrid(params)) class RISE_parameters(object): @@ -96,7 +99,7 @@ def __init__(self, RISE_config = RISE_parameters( p_keep = np.arange(.1, 1, .1), - feature_res=np.arange(1, 10, 2), + feature_res=np.arange(2, 11, 2), n_masks=np.arange(1000, 4000, 500) ) @@ -105,7 +108,7 @@ def __init__(self, num_samples=np.arange(1000, 4000, 500), kernel_width=np.geomspace(0.01, 3, num=5), distance_metric=[None], # will extend later - segmentation_fn=slic, + segmentation_fn=[slic], random_state = [42] ) diff --git a/relevance_maps_properties/experiments/run_experiments.py b/relevance_maps_properties/experiments/run_experiments.py index 19ae802..c3ba251 100644 --- a/relevance_maps_properties/experiments/run_experiments.py +++ b/relevance_maps_properties/experiments/run_experiments.py @@ -2,11 +2,10 @@ import dianna import quantus import json +import warnings import numpy as np -from dianna.utils.onnx_runner import SimpleModelRunner -from multiprocessing import Process from numpy.typing import NDArray from onnx import load from onnx2keras import onnx_to_keras @@ -15,12 +14,17 @@ from tqdm import tqdm from time import time_ns from typing import Callable, Union, Optional +from functools import partialmethod # Local imports from .hyperparameter_configs import SHAP_config, LIME_config, RISE_config, create_grid from ..metrics.metrics import Incremental_deletion from ..metrics import utils +# Silence warnings and tqdm progress bars by default +tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) +warnings.filterwarnings("ignore") + class Experiments(object): '''Class for the hyperparamter experiments. @@ -36,8 +40,6 @@ def __init__(self, model: Union[ModelProto, str, Path], n_samples: int = 5, preprocess_function: Optional[Callable] = None, - evaluator_kwargs: Optional[dict] = None, - model_kwargs: Optional[dict] = None, **kwargs): ''' Args: @@ -50,26 +52,27 @@ def __init__(self, TypeError: In case model type mismatched with expected tpyes ''' # Model preprocessing for cross-framework evaluation - if isinstance(model, (str, Path)): - model = load(model) - if isinstance(model, ModelProto): - self.model = dianna.utils.get_function(model, preprocess_function=preprocess_function) - input_names, _ = utils.get_onnx_names(self.model) - self.keras_model = onnx_to_keras(self.model, input_names) - else: - raise TypeError('`model_or_function` failed to convert to Keras.') + self.model = dianna.utils.get_function(model, + preprocess_function=preprocess_function) + onnx_model = load(model) + input_names, _ = utils.get_onnx_names(onnx_model) + self.keras_model = onnx_to_keras(onnx_model, input_names, + name_policy='renumerate', verbose=False) self.n_samples = n_samples - id_kwargs = dianna.utils.get_kwargs_applicable_to_function(Incremental_deletion.__init__, evaluator_kwargs) - quantus_kwargs = dianna.utils.get_kwargs_applicable_to_function(quantus.AvgSensitivity.__init__, evaluator_kwargs) + id_kwargs = dianna.utils.get_kwargs_applicable_to_function( + Incremental_deletion.__init__, kwargs) + quantus_kwargs = dianna.utils.get_kwargs_applicable_to_function( + quantus.AvgSensitivity.__init__,kwargs) - self.incr_del = Incremental_deletion(self.model, **id_kwargs, **model_kwargs) + self.incr_del = Incremental_deletion(self.model, **id_kwargs) self.avg_sensitivity = quantus.AvgSensitivity(nr_samples=self.n_samples, **quantus_kwargs) self.max_sensitivity = quantus.MaxSensitivity(nr_samples=self.n_samples, **quantus_kwargs) - def init_JSON_format(experiment_name: str, n_images: int, n_configs: int) -> dict: + def init_JSON_format(self, experiment_name: str, + n_images: int, n_configs: int) -> dict: ''' Return the hierarchical structure and metadata for the experiments data. Returns the data format that `explain_evaluate_images` expects to dump the @@ -83,7 +86,7 @@ def init_JSON_format(experiment_name: str, n_images: int, n_configs: int) -> dic Base dictionary representing JSON structure as output format. ''' output = {'experiment_name': experiment_name, - 'image': [ + 'images': [ { 'image_id': 0, 'imag_data': [], @@ -96,10 +99,10 @@ def init_JSON_format(experiment_name: str, n_images: int, n_configs: int) -> dic 'avg_sensitivity': 0., 'max_sensitivity': 0., 'run_time': 0., - } - ] * n_configs - } - ] * n_images + } for _ in range(n_configs) + ] + } for _ in range(n_images) + ] } return output @@ -108,8 +111,9 @@ def explain_evaluate_images(self, data: NDArray, method: str, grid: list[dict], + batch_size=64, save_between: int = 100, - model_kwargs: Optional[dict] = None + model_kwargs: dict = {} ) -> None: ''' This function will run our explainers and evaluators. @@ -120,57 +124,74 @@ def explain_evaluate_images(self, grid: The grid of possible hyperparameter configurations save_between: Save results for every save_between images model_kwargs: Kwargs to use for the model - - ''' if output_file.suffix != '.json': raise ValueError('`output_file` must end with `.json`.') + if data.ndim != 4: + raise ValueError('Dimension of `data` must be 4') explainer = self._get_explain_func(method) - results = self.init_JSON_format(data.shape[0], len(grid)) - - for image_id, image_data in enumerate(tqdm(data, desc='Running Experiments')): - results['images'][image_id] - for config_id, explainer_params in enumerate(grid): - results['runs']['image_id'][image_id]['params_id'] = {} - salient_batch = np.empty((self.n_samples, *image_data.shape[:2])) + results = self.init_JSON_format(method + 'Experiment', data.shape[0], len(grid)) + run_times = np.empty(self.n_samples) + + salient_batch = np.empty((self.n_samples, *data.shape[1:3])) + + for image_id, image_data in enumerate(tqdm(data, desc='Running Experiments', + disable=False, position=0)): + label = self.model(image_data[np.newaxis, ...], + **model_kwargs).argmax()[np.newaxis, ...] + for config_id, explainer_params in enumerate(tqdm(grid, desc='Trying out configurations', + disable=False, position=1, + leave=True)): + # TODO: Ensure this block happens outside this VERY expensive loop + explainer_params['labels'] = label + explainer_params['model_or_function'] = self.model + explainer_params['input_data'] = image_data - start_time = time_ns() for i in range(self.n_samples): - salient_batch[i] = explainer(image_data, **explainer_params) - end_time = (time_ns() - start_time) / self.n_samples + start_time = time_ns() + salient_batch[i] = explainer(**explainer_params) + run_times[i] = time_ns() - start_time # Compute metrics - y_batch = self.model(image_data, **model_kwargs).argmax()[np.newaxis, ...] incr_del = self.incr_del(image_data, salient_batch, - batch_size=self.batch_size, - **model_kwargs).pop('salient_batch') + batch_size=batch_size, + **model_kwargs) + del incr_del['salient_scores'] + del incr_del['random_scores'] + + avg_sensitiviy = self.avg_sensitivity(model=self.keras_model, - x_batch=salient_batch, - y_batch=y_batch, - batch_size=self.batch_size) + x_batch=image_data[np.newaxis, ...], + y_batch=label, + batch_size=batch_size, + explain_func=explainer, + explain_func_kwargs=explainer_params) max_sensitivity = self.max_sensitivity(model=self.keras_model, - x_batch=image_data, - y_batch=y_batch, - batch_size=self.batch_size) + x_batch=image_data[np.newaxis, ...], + y_batch=label, + batch_size=batch_size, + explain_func=explainer, + explain_func_kwargs=explainer_params) # Save results results['images'][image_id]['configs'][config_id]['incremental_deletion'] = incr_del results['images'][image_id]['configs'][config_id]['avg_sensitivity'] = avg_sensitiviy results['images'][image_id]['configs'][config_id]['max_sensitiviy'] = max_sensitivity - results['run_time'] = end_time - start_time + results['images'][image_id]['configs'][config_id]['run_time'] = np.median(run_times) - # Write imbetween result to file in case of runtime failures - if image_id % save_between == 0: - print(f"Backing up at iteration {image_id}") - with open(output_file, 'w') as f_out: - json.dumps(results, f_out) + # Write imbetween result to file in case of runtime failures + if image_id % save_between == 0: + print(f"Backing up at iteration {image_id}") + with open(output_file, 'w') as fp: + json.dump(results, fp) # Save final results. - with open(output_file, 'w') as f_out: - json.dumps(results, f_out) + with open(output_file, 'w') as fp: + json.dump(results, fp) + @staticmethod def _get_explain_func(method: str) -> Callable: '''Helper func to return appropriate explain function for method. @@ -182,12 +203,12 @@ def _get_explain_func(method: str) -> Callable: if not isinstance(method, str): raise TypeError('Please provide `method` as type str') - if method.to_upper() == 'KERNELSHAP': + if method.upper() == 'KERNELSHAP': return utils.SHAP_postprocess - elif method.to_upper() == 'LIME': + elif method.upper() == 'LIME': return utils.LIME_postprocess - elif method.to_upper() == 'RISE': - return dianna.explain_image + elif method.upper() == 'RISE': + return utils.RISE_postprocess else: raise ValueError('''Given method is not supported, please choose between KernelShap, RISE and LIME.''') @@ -202,6 +223,12 @@ def pool_handler(): raise NotImplementedError() +def load_MNIST(data: Union[str, Path]) -> NDArray: + f_store = np.load(data) + images = f_store['X_test'].astype(np.float32) + return images.reshape([-1, 28, 28, 1]) / 255 + + def main(): ''' Main function to run the experiments. @@ -219,15 +246,17 @@ def main(): args = parser.parse_args() kwargs = vars(args) + model = str(Path(kwargs.pop('model')).absolute()) + out = kwargs.pop('out') - data = np.load(kwargs.pop('data')) - for method, grid in zip(['RISE', 'LIME', 'KernelSHAP'], + data = load_MNIST(kwargs.pop('data')) + for method, config in zip(['RISE', 'LIME', 'KernelSHAP'], [RISE_config, LIME_config, SHAP_config]): - out = Path(kwargs.pop('out') / method / '.json') - experiments = Experiments(kwargs.pop('model') **kwargs) - proc = Process(target=experiments.explain_evaluate_images, - args=[out, data, method], kwargs=kwargs) - proc.start + grid = create_grid(config) + out = Path(out) / (method + '.json') + experiments = Experiments(model, **kwargs) + kwargs = dianna.utils.get_kwargs_applicable_to_function(experiments.explain_evaluate_images, kwargs) + experiments.explain_evaluate_images(out, data, method, grid, **kwargs) if __name__ == '__main__': main() \ No newline at end of file diff --git a/relevance_maps_properties/metrics/metrics.py b/relevance_maps_properties/metrics/metrics.py index 9d43b2b..2a0bfb6 100644 --- a/relevance_maps_properties/metrics/metrics.py +++ b/relevance_maps_properties/metrics/metrics.py @@ -2,7 +2,6 @@ import dianna import numpy as np -import seaborn as sns import matplotlib.pyplot as plt from dianna.utils import get_function @@ -66,9 +65,8 @@ def __call__(self, impute_method, **model_kwargs) x = np.arange(salient_scores.size) / salient_scores.size salient_auc = auc(x, salient_scores) - if not 'salient_scores' in results: - results['salient_scores'].append(salient_scores) - results['salient_auc'].append(salient_auc) + results['salient_scores'].append(salient_scores) + results['salient_auc'].append(salient_auc) if evaluate_random_baseline: for _ in range(salient_batch.shape[0]): diff --git a/relevance_maps_properties/metrics/utils.py b/relevance_maps_properties/metrics/utils.py index c2162ab..cf82c53 100644 --- a/relevance_maps_properties/metrics/utils.py +++ b/relevance_maps_properties/metrics/utils.py @@ -29,8 +29,8 @@ def LIME_postprocess(*args, **kwargs) -> NDArray: DIANNA yields: list[NDArray[(Any, Any), Any]] Quantus expects: NDArray((Any, Any, Any), Any) ''' - results = dianna.explain_image(method='LIME', *args, **kwargs) - return np.array(results)[0][None, ...] + results = dianna.explain_image(*args, method='LIME', **kwargs) + return np.array(results) def SHAP_postprocess(label, *args, **kwargs) -> NDArray: @@ -44,6 +44,16 @@ def SHAP_postprocess(label, *args, **kwargs) -> NDArray: saliences = list(_fill_segmentation(shapley_values[label][0], segments_slic)) return np.array(saliences)[np.newaxis, ..., np.newaxis] +def RISE_postprocess(*args, **kwargs) -> NDArray: + ''' + Post-process the output of DIANNA LIME in according to what Quantus expects. + + DIANNA yields: list[NDArray[(Any, Any), Any]] + Quantus expects: NDArray((Any, Any, Any), Any) + ''' + results = dianna.explain_image(*args, method='RISE', **kwargs) + return np.array(results) + def _fill_segmentation(values: NDArray, segmentation: NDArray) -> NDArray: ''' diff --git a/requirements.in b/requirements.in index 9864dd6..920039a 100644 --- a/requirements.in +++ b/requirements.in @@ -1,6 +1,6 @@ --find-links https://download.pytorch.org/whl/torch_stable.html -dianna @ git+https://github.com/dianna-ai/dianna.git # Pull DIANNA from dev branch +dianna @ git+https://github.com/dianna-ai/dianna.git # Pull DIANNA from main branch torch==1.9.0+cpu captum>=0.4.0 onnx2keras>=0.0.24 From 8d6fe163b2104b2711594f619fb7126b125da099 Mon Sep 17 00:00:00 2001 From: WillemSpek Date: Fri, 23 Jun 2023 15:17:34 +0200 Subject: [PATCH 3/4] added gpu acceleration --- .../experiments/run_experiments.py | 42 +++++++------- .../experiments/runners.py | 56 +++++++++++++++++++ 2 files changed, 78 insertions(+), 20 deletions(-) create mode 100644 relevance_maps_properties/experiments/runners.py diff --git a/relevance_maps_properties/experiments/run_experiments.py b/relevance_maps_properties/experiments/run_experiments.py index c3ba251..db79554 100644 --- a/relevance_maps_properties/experiments/run_experiments.py +++ b/relevance_maps_properties/experiments/run_experiments.py @@ -1,25 +1,26 @@ import argparse -import dianna -import quantus import json import warnings +from functools import partialmethod +from pathlib import Path +from time import time_ns +from typing import Callable, Optional, Union +import dianna import numpy as np - +import quantus from numpy.typing import NDArray -from onnx import load -from onnx2keras import onnx_to_keras +from onnx import load from onnx.onnx_ml_pb2 import ModelProto -from pathlib import Path +from onnx2keras import onnx_to_keras from tqdm import tqdm -from time import time_ns -from typing import Callable, Union, Optional -from functools import partialmethod -# Local imports -from .hyperparameter_configs import SHAP_config, LIME_config, RISE_config, create_grid -from ..metrics.metrics import Incremental_deletion +from .runners import ModelRunner from ..metrics import utils +from ..metrics.metrics import Incremental_deletion + +# Local imports +from .hyperparameter_configs import LIME_config, RISE_config, SHAP_config, create_grid # Silence warnings and tqdm progress bars by default tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) @@ -52,8 +53,7 @@ def __init__(self, TypeError: In case model type mismatched with expected tpyes ''' # Model preprocessing for cross-framework evaluation - self.model = dianna.utils.get_function(model, - preprocess_function=preprocess_function) + self.model = ModelRunner(model) onnx_model = load(model) input_names, _ = utils.get_onnx_names(onnx_model) self.keras_model = onnx_to_keras(onnx_model, input_names, @@ -111,6 +111,7 @@ def explain_evaluate_images(self, data: NDArray, method: str, grid: list[dict], + device: str = 'cpu', batch_size=64, save_between: int = 100, model_kwargs: dict = {} @@ -130,10 +131,11 @@ def explain_evaluate_images(self, if data.ndim != 4: raise ValueError('Dimension of `data` must be 4') + if device == 'gpu': + self.model.__call__ = partialmethod(self.model.__call__, device=1) explainer = self._get_explain_func(method) results = self.init_JSON_format(method + 'Experiment', data.shape[0], len(grid)) run_times = np.empty(self.n_samples) - salient_batch = np.empty((self.n_samples, *data.shape[1:3])) for image_id, image_data in enumerate(tqdm(data, desc='Running Experiments', @@ -150,7 +152,8 @@ def explain_evaluate_images(self, for i in range(self.n_samples): start_time = time_ns() - salient_batch[i] = explainer(**explainer_params) + salient_batch[i] = explainer(batch_size=500, + **explainer_params) run_times[i] = time_ns() - start_time # Compute metrics @@ -158,10 +161,6 @@ def explain_evaluate_images(self, salient_batch, batch_size=batch_size, **model_kwargs) - del incr_del['salient_scores'] - del incr_del['random_scores'] - - avg_sensitiviy = self.avg_sensitivity(model=self.keras_model, x_batch=image_data[np.newaxis, ...], y_batch=label, @@ -176,6 +175,8 @@ def explain_evaluate_images(self, explain_func_kwargs=explainer_params) # Save results + del incr_del['salient_scores'] + del incr_del['random_scores'] results['images'][image_id]['configs'][config_id]['incremental_deletion'] = incr_del results['images'][image_id]['configs'][config_id]['avg_sensitivity'] = avg_sensitiviy results['images'][image_id]['configs'][config_id]['max_sensitiviy'] = max_sensitivity @@ -242,6 +243,7 @@ def main(): parser.add_argument('--out', type=str, default='./') parser.add_argument('--step', type=int, default=2) parser.add_argument('--batch_size', type=int, default=100) + parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--n_samples', type=int, default=5) args = parser.parse_args() diff --git a/relevance_maps_properties/experiments/runners.py b/relevance_maps_properties/experiments/runners.py new file mode 100644 index 0000000..fd39fd5 --- /dev/null +++ b/relevance_maps_properties/experiments/runners.py @@ -0,0 +1,56 @@ +import onnxruntime as ort + +from pathlib import Path +from typing import Union, Optional, Callable + + +class ModelRunner: + """Runs an onnx model with a set of inputs and outputs.""" + def __init__(self, + filename: Union[str, Path], + preprocess_function: Optional[Callable] = None + ): + """Generates function to run ONNX model with one set of inputs and outputs. + + Args: + filename: Path to ONNX model on disk + preprocess_function: Function to preprocess input data with + + Returns: + function + + Examples: + >>> runner = SimpleModelRunner('path_to_model.onnx') + >>> predictions = runner(input_data) + """ + self.filename = filename + self.preprocess_function = preprocess_function + + def __call__(self, input_data, device=0): + """Get ONNX predictions.""" + EP_list = self._set_EP(device) + sess = ort.InferenceSession(self.filename, providers=EP_list) + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + + if self.preprocess_function is not None: + input_data = self.preprocess_function(input_data) + + onnx_input = {input_name: input_data} + pred_onnx = sess.run([output_name], onnx_input)[0] + return pred_onnx + + @staticmethod + def _set_EP(device: int) -> list: + if device == 0: # CPU + return ['CPUExecutionProvider'] + elif device == 1: # GPU + return [ + ("CUDAExecutionProvider", + {"cudnn_conv_algo_search": "DEFAULT"} + ), + "TensorRTExecutionProvider", + " CPUExecutionProvider" + ] + else: + raise ValueError('Device has to be 0 (CPU) or 1 (GPU).') From b4d23e91c293334ce5ba65f9504d0d157bed2966 Mon Sep 17 00:00:00 2001 From: WillemSpek Date: Thu, 29 Jun 2023 11:56:40 +0200 Subject: [PATCH 4/4] further updated experiments --- .../experiments/config_wrappers.py | 26 ++++ .../experiments/distance_metrics.py | 5 - .../experiments/hyperparameter_configs.py | 96 +++++++++---- .../experiments/run_experiments.py | 136 ++++++++---------- .../experiments/runners.py | 33 +++-- relevance_maps_properties/metrics/metrics.py | 20 +-- relevance_maps_properties/metrics/utils.py | 11 +- 7 files changed, 188 insertions(+), 139 deletions(-) create mode 100644 relevance_maps_properties/experiments/config_wrappers.py delete mode 100644 relevance_maps_properties/experiments/distance_metrics.py diff --git a/relevance_maps_properties/experiments/config_wrappers.py b/relevance_maps_properties/experiments/config_wrappers.py new file mode 100644 index 0000000..530e7f2 --- /dev/null +++ b/relevance_maps_properties/experiments/config_wrappers.py @@ -0,0 +1,26 @@ +from typing import Any +import numpy as np + +from scipy.stats import wasserstein_distance +from numpy.typing import NDArray +from skimage.segmentation import slic +from typing import Optional + + +class Slic_Wrapper(): + def __init__(self, + n_segments: int = 10, + compactness: float = 10., + sigma: float = 0.): + self.n_segments= n_segments + self.compactness = compactness + self.sigma = sigma + + def __call__(self, image): + return slic(image, + n_segments=self.n_segments, + compactness=self.compactness, + sigma = self.sigma) + + def __repr__(self): + return f'slic(n_segments={self.n_segments}, compactness={self.compactness}, sigma={self.sigma})' \ No newline at end of file diff --git a/relevance_maps_properties/experiments/distance_metrics.py b/relevance_maps_properties/experiments/distance_metrics.py deleted file mode 100644 index 7d629b9..0000000 --- a/relevance_maps_properties/experiments/distance_metrics.py +++ /dev/null @@ -1,5 +0,0 @@ -import numpy as np - -from scipy.stats import wasserstein_distance -from numpy.typing import NDArray - diff --git a/relevance_maps_properties/experiments/hyperparameter_configs.py b/relevance_maps_properties/experiments/hyperparameter_configs.py index 28e3e3a..2a72e5c 100644 --- a/relevance_maps_properties/experiments/hyperparameter_configs.py +++ b/relevance_maps_properties/experiments/hyperparameter_configs.py @@ -1,25 +1,60 @@ +from collections.abc import Mapping, Sequence import numpy as np -from typing import Optional, Iterable -from skimage.segmentation import slic +from typing import Mapping, Optional, Iterable, Sequence, Union from sklearn.model_selection import ParameterGrid - - -def create_grid(parameters: object) -> list: - ''' Convert parameter objects to a grid containing all possible parameter - combinations. - - Args: - parameters: Parameters to use in the grid - - Returns: All possible parameter combinations - ''' - params = parameters.__dict__ - params = {k: params[k] for k in params.keys() - if params[k] is not None} - return list(ParameterGrid(params)) - - +from sklearn.linear_model import Ridge + +from .config_wrappers import Slic_Wrapper + + +class ParamGrid(ParameterGrid): + '''Wrapper for ParameterGrid from sklearn.model_selection''' + def __init__(self, param_grid: Union[Sequence, Mapping]) -> None: + cleaned_grid = {} + for key in param_grid: + if param_grid[key] is None: + continue + elif isinstance(param_grid[key], (np.ndarray, np.generic)): + cleaned_grid[key] = param_grid[key].tolist() + else: + cleaned_grid[key] = param_grid[key] + super().__init__(cleaned_grid) + + def __getitem__(self, ind: int) -> dict[str, list[str, int, float]]: + '''Slight modifitcation of the sklearn.model_selection.ParameterGrid implementation + + Tries to get the representation of non strings, floats and ints in order + to make this data serializable.''' + for sub_grid in self.param_grid: + if not sub_grid: + if ind == 0: + return {} + else: + ind -= 1 + continue + + # Reverse so most frequent cycling parameter comes first + keys, values_lists = zip(*sorted(sub_grid.items())[::-1]) + sizes = [len(v_list) for v_list in values_lists] + total = np.product(sizes) + + if ind >= total: + # Try the next grid + ind -= total + else: + out = {} + for key, v_list, n in zip(keys, values_lists, sizes): + ind, offset = divmod(ind, n) + val = v_list[offset] + if not isinstance(val, (str, float, int)): + val = str(val) + out[key] = val + return out + + raise IndexError("ParameterGrid index out of range") + + class RISE_parameters(object): '''Set up hyperparameters for RISE. ''' @@ -98,22 +133,25 @@ def __init__(self, RISE_config = RISE_parameters( + n_masks=np.arange(200, 2000, 400), p_keep = np.arange(.1, 1, .1), - feature_res=np.arange(2, 11, 2), - n_masks=np.arange(1000, 4000, 500) + feature_res=np.arange(3, 16, 3), + random_state=[42] ) LIME_config = LIME_parameters( - num_samples=np.arange(1000, 4000, 500), - kernel_width=np.geomspace(0.01, 3, num=5), - distance_metric=[None], # will extend later - segmentation_fn=[slic], - random_state = [42] + num_samples=np.arange(20, 200, 40), + kernel_width=np.geomspace(0.1, 3, num=5), + distance_metric=None, # will extend later + segmentation_fn=[Slic_Wrapper(n_segments=n) for n in range(10, 60,10)], + model_regressor=[Ridge(alpha=a) for a in [0, *np.geomspace(0.05, 3, num=4)]] + # random_state = [42] ) SHAP_config = SHAP_parameters( - nsamples=np.arange(1000, 4000, 500), - l1_reg=np.geomspace(.001, 1, num=5) -) \ No newline at end of file + nsamples=np.arange(20, 200, 40), + l1_reg=[0, *np.geomspace(0.05, 3, num=4)], + random_state=[42] +) diff --git a/relevance_maps_properties/experiments/run_experiments.py b/relevance_maps_properties/experiments/run_experiments.py index db79554..5f8fee0 100644 --- a/relevance_maps_properties/experiments/run_experiments.py +++ b/relevance_maps_properties/experiments/run_experiments.py @@ -1,12 +1,14 @@ import argparse import json import warnings +from copy import copy from functools import partialmethod from pathlib import Path from time import time_ns from typing import Callable, Optional, Union import dianna +import matplotlib.pyplot as plt import numpy as np import quantus from numpy.typing import NDArray @@ -15,16 +17,16 @@ from onnx2keras import onnx_to_keras from tqdm import tqdm -from .runners import ModelRunner from ..metrics import utils from ..metrics.metrics import Incremental_deletion # Local imports -from .hyperparameter_configs import LIME_config, RISE_config, SHAP_config, create_grid +from .hyperparameter_configs import LIME_config, RISE_config, SHAP_config, ParamGrid +from .runners import ModelRunner -# Silence warnings and tqdm progress bars by default +# Silence imported progress bars and warnings tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) -warnings.filterwarnings("ignore") +warnings.simplefilter("ignore") class Experiments(object): @@ -53,7 +55,7 @@ def __init__(self, TypeError: In case model type mismatched with expected tpyes ''' # Model preprocessing for cross-framework evaluation - self.model = ModelRunner(model) + self.model = ModelRunner(model, preprocess_function=preprocess_function) onnx_model = load(model) input_names, _ = utils.get_onnx_names(onnx_model) self.keras_model = onnx_to_keras(onnx_model, input_names, @@ -71,8 +73,9 @@ def __init__(self, self.max_sensitivity = quantus.MaxSensitivity(nr_samples=self.n_samples, **quantus_kwargs) - def init_JSON_format(self, experiment_name: str, - n_images: int, n_configs: int) -> dict: + @staticmethod + def init_JSON_format(experiment_name: str, + n_configs: int) -> dict: ''' Return the hierarchical structure and metadata for the experiments data. Returns the data format that `explain_evaluate_images` expects to dump the @@ -80,40 +83,34 @@ def init_JSON_format(self, experiment_name: str, Args: experiment_name: Name for the experiment - n_images: Number of images to run the experiment on n_configs: Number of hyperparameter configurations Returns: Base dictionary representing JSON structure as output format. ''' output = {'experiment_name': experiment_name, - 'images': [ - { - 'image_id': 0, - 'imag_data': [], - 'configs': [ - { - 'config_id': 0, - 'config': [], - 'salient_batch': [], - 'incremental_deletion': {}, - 'avg_sensitivity': 0., - 'max_sensitivity': 0., - 'run_time': 0., - } for _ in range(n_configs) - ] - } for _ in range(n_images) + 'image_id': 0, + 'image_data': [], + 'model_scores': [], + 'configs': [ + { + 'config_id': 0, + 'config': [], + 'sensitivity': [], + 'run_time': 0., + 'incremental_deletion': {}, + 'salient_batch': [], + } for _ in range(n_configs) ] - } + } return output def explain_evaluate_images(self, - output_file: Union[str, Path], + output_folder: Union[str, Path], data: NDArray, method: str, grid: list[dict], device: str = 'cpu', batch_size=64, - save_between: int = 100, model_kwargs: dict = {} ) -> None: ''' This function will run our explainers and evaluators. @@ -126,15 +123,11 @@ def explain_evaluate_images(self, save_between: Save results for every save_between images model_kwargs: Kwargs to use for the model ''' - if output_file.suffix != '.json': - raise ValueError('`output_file` must end with `.json`.') if data.ndim != 4: raise ValueError('Dimension of `data` must be 4') - if device == 'gpu': - self.model.__call__ = partialmethod(self.model.__call__, device=1) explainer = self._get_explain_func(method) - results = self.init_JSON_format(method + 'Experiment', data.shape[0], len(grid)) + results = self.init_JSON_format(method + '_Experiment', len(grid)) run_times = np.empty(self.n_samples) salient_batch = np.empty((self.n_samples, *data.shape[1:3])) @@ -142,18 +135,19 @@ def explain_evaluate_images(self, disable=False, position=0)): label = self.model(image_data[np.newaxis, ...], **model_kwargs).argmax()[np.newaxis, ...] - for config_id, explainer_params in enumerate(tqdm(grid, desc='Trying out configurations', + for config_id, config in enumerate(tqdm(grid, desc='Trying out configurations', disable=False, position=1, leave=True)): - # TODO: Ensure this block happens outside this VERY expensive loop + + explainer_params = copy(config) # Prevent in-place modification of grid explainer_params['labels'] = label explainer_params['model_or_function'] = self.model explainer_params['input_data'] = image_data + explainer_params['batch_size'] = batch_size for i in range(self.n_samples): start_time = time_ns() - salient_batch[i] = explainer(batch_size=500, - **explainer_params) + salient_batch[i] = explainer(**explainer_params) run_times[i] = time_ns() - start_time # Compute metrics @@ -161,36 +155,24 @@ def explain_evaluate_images(self, salient_batch, batch_size=batch_size, **model_kwargs) - avg_sensitiviy = self.avg_sensitivity(model=self.keras_model, - x_batch=image_data[np.newaxis, ...], - y_batch=label, - batch_size=batch_size, - explain_func=explainer, - explain_func_kwargs=explainer_params) - max_sensitivity = self.max_sensitivity(model=self.keras_model, - x_batch=image_data[np.newaxis, ...], - y_batch=label, - batch_size=batch_size, - explain_func=explainer, - explain_func_kwargs=explainer_params) - + sensitivity = self.avg_sensitivity(model=self.keras_model, + x_batch=image_data[np.newaxis, ...], + y_batch=label, + batch_size=batch_size, + explain_func=explainer, + explain_func_kwargs=explainer_params) + # Save results - del incr_del['salient_scores'] - del incr_del['random_scores'] - results['images'][image_id]['configs'][config_id]['incremental_deletion'] = incr_del - results['images'][image_id]['configs'][config_id]['avg_sensitivity'] = avg_sensitiviy - results['images'][image_id]['configs'][config_id]['max_sensitiviy'] = max_sensitivity - results['images'][image_id]['configs'][config_id]['run_time'] = np.median(run_times) - - # Write imbetween result to file in case of runtime failures - if image_id % save_between == 0: - print(f"Backing up at iteration {image_id}") - with open(output_file, 'w') as fp: - json.dump(results, fp) + results['configs'][config_id]['config'] = grid[config_id] + results['configs'][config_id]['salient_batch'] = salient_batch.tolist() + results['configs'][config_id]['incremental_deletion'] = incr_del + results['configs'][config_id]['sensitivity'] = sensitivity + results['configs'][config_id]['run_time'] = np.median(run_times) - # Save final results. - with open(output_file, 'w') as fp: - json.dump(results, fp) + # Savel results + output_file = Path(output_folder) / ('image_' + str(image_id) + '.json') + with open(output_file, 'w') as fp: + json.dump(results, fp, indent=4) @staticmethod def _get_explain_func(method: str) -> Callable: @@ -227,7 +209,8 @@ def pool_handler(): def load_MNIST(data: Union[str, Path]) -> NDArray: f_store = np.load(data) images = f_store['X_test'].astype(np.float32) - return images.reshape([-1, 28, 28, 1]) / 255 + images = images.reshape([-1, 28, 28, 1]) / 255 + return images[:10] def main(): @@ -237,28 +220,27 @@ def main(): command-line arguments. ''' parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, required=True) parser.add_argument('--data', type=str, required=True) + parser.add_argument('--method', type=str, required=True) parser.add_argument('--out', type=str, default='./') parser.add_argument('--step', type=int, default=2) parser.add_argument('--batch_size', type=int, default=100) parser.add_argument('--device', type=str, default='cpu') - parser.add_argument('--n_samples', type=int, default=5) - args = parser.parse_args() kwargs = vars(args) + + out = Path(kwargs.pop('out')) + if not out.exists(): + raise ValueError('Please specify an existing path on the --out parameter') model = str(Path(kwargs.pop('model')).absolute()) - out = kwargs.pop('out') - data = load_MNIST(kwargs.pop('data')) - for method, config in zip(['RISE', 'LIME', 'KernelSHAP'], - [RISE_config, LIME_config, SHAP_config]): - grid = create_grid(config) - out = Path(out) / (method + '.json') - experiments = Experiments(model, **kwargs) - kwargs = dianna.utils.get_kwargs_applicable_to_function(experiments.explain_evaluate_images, kwargs) - experiments.explain_evaluate_images(out, data, method, grid, **kwargs) + method = kwargs.pop('method') + grid = ParamGrid(globals()[method[-4:] + '_config'].__dict__) + + experiments = Experiments(model, **kwargs) + kwargs = dianna.utils.get_kwargs_applicable_to_function(experiments.explain_evaluate_images, kwargs) + experiments.explain_evaluate_images(out, data, method, grid, **kwargs) if __name__ == '__main__': main() \ No newline at end of file diff --git a/relevance_maps_properties/experiments/runners.py b/relevance_maps_properties/experiments/runners.py index fd39fd5..2553be7 100644 --- a/relevance_maps_properties/experiments/runners.py +++ b/relevance_maps_properties/experiments/runners.py @@ -6,38 +6,39 @@ class ModelRunner: """Runs an onnx model with a set of inputs and outputs.""" - def __init__(self, - filename: Union[str, Path], - preprocess_function: Optional[Callable] = None + def __init__(self, + filename: Union[str, Path], + preprocess_function: Optional[Callable] = None, + device: int = 0 ): """Generates function to run ONNX model with one set of inputs and outputs. - Args: filename: Path to ONNX model on disk preprocess_function: Function to preprocess input data with - Returns: function - Examples: >>> runner = SimpleModelRunner('path_to_model.onnx') >>> predictions = runner(input_data) """ self.filename = filename self.preprocess_function = preprocess_function + self.device = device - def __call__(self, input_data, device=0): - """Get ONNX predictions.""" EP_list = self._set_EP(device) - sess = ort.InferenceSession(self.filename, providers=EP_list) - input_name = sess.get_inputs()[0].name - output_name = sess.get_outputs()[0].name + self.sess = ort.InferenceSession(self.filename, providers=EP_list) + + def __call__(self, input_data, device=0): + """Get ONNX predictions.""" + + input_name = self.sess.get_inputs()[0].name + output_name = self.sess.get_outputs()[0].name if self.preprocess_function is not None: input_data = self.preprocess_function(input_data) onnx_input = {input_name: input_data} - pred_onnx = sess.run([output_name], onnx_input)[0] + pred_onnx = self.sess.run([output_name], onnx_input)[0] return pred_onnx @staticmethod @@ -45,12 +46,10 @@ def _set_EP(device: int) -> list: if device == 0: # CPU return ['CPUExecutionProvider'] elif device == 1: # GPU - return [ - ("CUDAExecutionProvider", + return [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"} ), - "TensorRTExecutionProvider", - " CPUExecutionProvider" + "CPUExecutionProvider" ] else: - raise ValueError('Device has to be 0 (CPU) or 1 (GPU).') + raise ValueError('Device has to be 0 (CPU) or 1 (GPU).') \ No newline at end of file diff --git a/relevance_maps_properties/metrics/metrics.py b/relevance_maps_properties/metrics/metrics.py index 2a0bfb6..cd77f4e 100644 --- a/relevance_maps_properties/metrics/metrics.py +++ b/relevance_maps_properties/metrics/metrics.py @@ -56,7 +56,6 @@ def __call__(self, if not salient_batch.ndim == 3: raise ValueError(f'Salient batch has wrong dimenions, expected ndim=3, \ got ndim={salient_batch.ndim}.') - results = defaultdict(list) for salience_map in salient_batch: @@ -65,16 +64,19 @@ def __call__(self, impute_method, **model_kwargs) x = np.arange(salient_scores.size) / salient_scores.size salient_auc = auc(x, salient_scores) - results['salient_scores'].append(salient_scores) results['salient_auc'].append(salient_auc) + results['salient_scores'].append(salient_scores.tolist()) if evaluate_random_baseline: + if random_seed is not None: + np.random.seed(random_seed) for _ in range(salient_batch.shape[0]): random_order = self.get_random_order(input_img.shape[:2], random_seed) random_scores = self.evaluate(input_img, random_order, batch_size, impute_method, **model_kwargs) - results['random_scores'].append(random_scores) - results['random_auc'].append(auc(x, random_scores)) + random_auc = auc(x, random_scores) + results['random_auc'].append(random_auc) + results['random_scores'].append(random_scores.tolist()) return results def evaluate(self, @@ -197,20 +199,22 @@ def _make_impute(self, input_img: NDArray, @staticmethod def get_salient_order(salience_map: NDArray) -> NDArray: '''Return the order of relvances in terms of indices of `salience_map` - + + NOTE: Mergesort is necessary for sorting stability, i.e. + deleting pixels in proximity is desirable when neighbouring scores have + the same values. Args: salience_map: map of salient scores Returns: Indices of `salience_map` sorted by their value. ''' - return np.stack(np.unravel_index(np.argsort(salience_map, axis=None), + return np.stack(np.unravel_index(np.argsort(salience_map, axis=None, + kind='mergesort'), salience_map.shape), axis=-1)[::-1] @staticmethod def get_random_order(image_shape: tuple, random_seed: Optional[int] = 0) -> NDArray: '''Get a random order of coordinates ''' - if isinstance(random_seed, int): - np.random.seed(random_seed) indices = np.argwhere(np.ones(image_shape)) # Hack to get all cartesian coordinates np.random.shuffle(indices) diff --git a/relevance_maps_properties/metrics/utils.py b/relevance_maps_properties/metrics/utils.py index cf82c53..d7890c5 100644 --- a/relevance_maps_properties/metrics/utils.py +++ b/relevance_maps_properties/metrics/utils.py @@ -2,8 +2,10 @@ import numpy as np +from dianna.methods.lime import LIMEImage from onnx.onnx_ml_pb2 import ModelProto from numpy.typing import NDArray +from typing import Union def get_onnx_names(onnx_model: ModelProto) -> tuple: @@ -29,11 +31,14 @@ def LIME_postprocess(*args, **kwargs) -> NDArray: DIANNA yields: list[NDArray[(Any, Any), Any]] Quantus expects: NDArray((Any, Any, Any), Any) ''' - results = dianna.explain_image(*args, method='LIME', **kwargs) + init_kwargs = dianna.utils.get_kwargs_applicable_to_function(LIMEImage.__init__, + kwargs) + explainer = LIMEImage(**init_kwargs) + results = explainer.explain(*args, method='LIME', return_masks=False, **kwargs) return np.array(results) -def SHAP_postprocess(label, *args, **kwargs) -> NDArray: +def SHAP_postprocess(*args, **kwargs) -> NDArray: ''' Post-process the output of DIANNA KernelSHAP in according to what Quantus expects. @@ -41,7 +46,7 @@ def SHAP_postprocess(label, *args, **kwargs) -> NDArray: Quantus expects: NDArray((Any, Any, Any), Any) ''' shapley_values, segments_slic = dianna.explain_image(method='KernelSHAP', *args, **kwargs) - saliences = list(_fill_segmentation(shapley_values[label][0], segments_slic)) + saliences = list(_fill_segmentation(shapley_values[0], segments_slic)) return np.array(saliences)[np.newaxis, ..., np.newaxis] def RISE_postprocess(*args, **kwargs) -> NDArray: