diff --git a/relevance_maps_properties/metrics/metrics.py b/relevance_maps_properties/metrics/metrics.py
index cd77f4e..ef36721 100644
--- a/relevance_maps_properties/metrics/metrics.py
+++ b/relevance_maps_properties/metrics/metrics.py
@@ -10,7 +10,7 @@
 from numpy.typing import NDArray
 from tqdm import tqdm
 from sklearn.metrics import auc
-from scipy.stats import mode
+from scipy.stats import mode, pearsonr, spearmanr
 from copy import copy
 from PIL.Image import Image
 from torchtext.vocab import Vectors
@@ -297,10 +297,44 @@ def __init__(self, model: Union[Callable, str],
         self.tokenizer = get_function(tokenizer, preprocess_function=None)
         self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
         self.max_filter_size = max_filter_size
-        self.max_filter_size = max_filter_size
         self.pad_token = pad_token
         self.unk_token = unk_token
 
+    def __call__(self,
+                 input_text: str,
+                 salient_batch: list[list[tuple[str, str, float]]],
+                 normalise: bool = False,
+                 normalise_fn: Optional[Callable] = None,
+                 impute_value: str = '',
+                 p_thresh: float = .05
+                 ):
+
+        results = defaultdict(list)
+        for salience_map in salient_batch:
+            scores, init_score = self.evaluate(salience_map, input_text,
+                                               impute_value=impute_value)
+            scores = init_score - scores
+            _, _, relevances = self.sort_salience_map(salience_map)
+            if normalise:
+                relevances = normalise_fn(np.array(relevances))
+
+            pearson = pearsonr(relevances, scores, alternative='greater')
+            spearman = spearmanr(relevances, scores, alternative='greater')
+            if pearson[1] > p_thresh:
+                continue
+            if spearman[1] > p_thresh:
+                continue
+            results['scores'].append(scores.tolist())
+            results['pearson'].append(pearson[0])
+            results['spearman'].append(spearman[0])
+
+        if not results['scores']:
+            raise RuntimeError('''Couldn't find reliable correlation estimates in the given
+                               batch of explanations. This is likely due to a poor choice
+                               of hyperparameters. Please recompute the explanations.''')
+        results['init_score'] = float(init_score)
+        return results
+
     def evaluate(self,
                  salience_map: list[tuple[str, str, float]],
                  input_sentence: str,
@@ -320,35 +354,26 @@ def evaluate(self,
 
             Perturbed sentence scores and initial sentence score
         '''
 
         # Tokenize setence.
-        # Tokenize setence.
         tokenized = self._preprocess_sentence(input_sentence)
         eval_sentence = copy(tokenized)
         _, indices, _ = self.sort_salience_map(salience_map)
-        _, indices, _ = self.sort_salience_map(salience_map)
 
         # Get original sentence score.
         init_pred = self.model([eval_sentence], **model_kwargs)
         init_score = init_pred.max()
         init_lbl = init_pred.argmax()
-        # Get original sentence score.
-        init_pred = self.model([eval_sentence], **model_kwargs)
-        init_score = init_pred.max()
-        init_lbl = init_pred.argmax()
-
+        impute_value = self.vocab.stoi[impute_value]
         scores = np.empty(len(salience_map))
         for i, token_idx in enumerate(indices):
-            # Perturb sentence and score model.
             # Perturb sentence and score model.
             tmp = eval_sentence[token_idx]
            eval_sentence[token_idx] = impute_value
             score = self.model([eval_sentence], **model_kwargs).flatten()[init_lbl]
-            score = self.model([eval_sentence], **model_kwargs).flatten()[init_lbl]
             eval_sentence[token_idx] = tmp
             scores[i] = score
         return scores, init_score
-        return scores, init_score
 
     def _preprocess_sentence(self, input_sentence: str) -> list:
         '''Tokenize and embed sentence.
@@ -360,8 +385,6 @@ def _preprocess_sentence(self, input_sentence: str) -> list:
         tokens = self.tokenizer(input_sentence)
         if len(tokens) < self.max_filter_size:
             tokens += [self.pad_token] * (self.max_filter_size - len(tokens))
-        if len(tokens) < self.max_filter_size:
-            tokens += [self.pad_token] * (self.max_filter_size - len(tokens))
         embedded = [self.vocab.stoi[token] if token in self.vocab.stoi
                     else self.vocab.stoi[self.unk_token] for token in tokens]
 
@@ -386,7 +409,6 @@ def visualize(self, salience_map: tuple[str, str, float],
         '''
         assert len(scores) >= len(salience_map)
         words, indices, relevances = self.sort_salience_map(salience_map)
-        words, indices, relevances = self.sort_salience_map(salience_map)
 
 
         fig, ax1 = plt.subplots()
diff --git a/relevance_maps_properties/metrics/sensitivity.py b/relevance_maps_properties/metrics/sensitivity.py
new file mode 100644
index 0000000..77f2297
--- /dev/null
+++ b/relevance_maps_properties/metrics/sensitivity.py
@@ -0,0 +1,377 @@
+"""This module contains the implementation of the Avg-Sensitivity metric."""
+
+# This file is part of Quantus.
+# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
+# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+
+
+from typing import Any, Callable, Dict, List, Optional
+import numpy as np
+
+from quantus.helpers import asserts
+from quantus.functions import norm_func
+from quantus.helpers import warn
+from quantus.helpers.model.model_interface import ModelInterface
+from quantus.functions.normalise_func import normalise_by_max
+from quantus.functions.perturb_func import uniform_noise, perturb_batch
+from quantus.functions.similarity_func import difference
+from quantus.metrics.base_batched import BatchedPerturbationMetric
+
+
+import quantus
+import dianna
+
+
+class Sensitivity(BatchedPerturbationMetric):
+    """
+    Implementation of Avg-Sensitivity by Yeh et al., 2019.
+
+    Using a Monte Carlo sampling-based approximation, the average sensitivity is
+    captured by measuring how explanations change under slight input perturbations.
+
+    References:
+        1) Chih-Kuan Yeh et al. "On the (in)fidelity and sensitivity for explanations."
+        NeurIPS (2019): 10965-10976.
+        2) Umang Bhatt et al.: "Evaluating and aggregating
+        feature-based model explanations." IJCAI (2020): 3016-3022.
+ """ + + def __init__( + self, + get_embeddings: Callable, + similarity_func: Optional[Callable] = None, + norm_numerator: Optional[Callable] = None, + norm_denominator: Optional[Callable] = None, + nr_samples: int = 200, + abs: bool = False, + normalise: bool = False, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Callable = None, + lower_bound: float = 0.2, + upper_bound: Optional[float] = None, + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Callable = np.mean, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + return_nan_when_prediction_changes: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + similarity_func: callable + Similarity function applied to compare input and perturbed input. + If None, the default value is used, default=difference. + norm_numerator: callable + Function for norm calculations on the numerator. + If None, the default value is used, default=fro_norm + norm_denominator: callable + Function for norm calculations on the denominator. + If None, the default value is used, default=fro_norm + nr_samples: integer + The number of samples iterated, default=200. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=gaussian_noise. + perturb_std: float + The amount of noise added, default=0.1. + perturb_mean: float + The mean of noise added, default=0.0. + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + return_nan_when_prediction_changes: boolean + When set to true, the metric will be evaluated to NaN if the prediction changes after the perturbation is applied. + kwargs: optional + Keyword arguments. + """ + self.get_embeddings = get_embeddings + + if normalise_func is None: + normalise_func = normalise_by_max + + if perturb_func is None: + perturb_func = uniform_noise + + if perturb_func_kwargs is None: + perturb_func_kwargs = {} + perturb_func_kwargs["lower_bound"] = lower_bound + perturb_func_kwargs["upper_bound"] = upper_bound + + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + perturb_func=perturb_func, + perturb_func_kwargs=perturb_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. 
+        self.nr_samples = nr_samples
+
+        if similarity_func is None:
+            similarity_func = difference
+        self.similarity_func = similarity_func
+
+        if norm_numerator is None:
+            norm_numerator = norm_func.fro_norm
+        self.norm_numerator = norm_numerator
+
+        if norm_denominator is None:
+            norm_denominator = norm_func.fro_norm
+        self.norm_denominator = norm_denominator
+        self.return_nan_when_prediction_changes = return_nan_when_prediction_changes
+
+        # Asserts and warnings.
+        if not self.disable_warnings:
+            warn.warn_parameterisation(
+                metric_name=self.__class__.__name__,
+                sensitive_params=(
+                    "amount of noise added 'lower_bound' and 'upper_bound', the number of samples "
+                    "iterated over 'nr_samples', the function to perturb the input "
+                    "'perturb_func', the similarity metric 'similarity_func' as well as "
+                    "norm calculations on the numerator and denominator of the sensitivity"
+                    " equation i.e., 'norm_numerator' and 'norm_denominator'"
+                ),
+                citation=(
+                    "Yeh, Chih-Kuan, et al. 'On the (in)fidelity and sensitivity for explanations"
+                    ".' arXiv preprint arXiv:1901.09392 (2019)"
+                ),
+            )
+            warn.warn_noise_zero(noise=lower_bound)
+
+    def __call__(
+        self,
+        model,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: Optional[np.ndarray] = None,
+        s_batch: Optional[np.ndarray] = None,
+        channel_first: Optional[bool] = None,
+        explain_func: Optional[Callable] = None,
+        explain_func_kwargs: Optional[Dict] = None,
+        model_predict_kwargs: Optional[Dict] = None,
+        softmax: Optional[bool] = False,
+        device: Optional[str] = None,
+        batch_size: int = 64,
+        custom_batch: Optional[Any] = None,
+        **kwargs,
+    ) -> List[float]:
+        """
+        This implementation represents the main logic of the metric and makes the class object callable.
+        It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch),
+        output labels (y_batch) and a torch or tensorflow model (model).
+
+        Calls general_preprocess() with all relevant arguments, calls
+        evaluate_batch() on each batch of instances, and saves results to last_results.
+        Calls custom_postprocess() afterwards. Finally returns last_results.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        s_batch: np.ndarray, optional
+            A np.ndarray which contains segmentation masks that matches the input.
+        channel_first: boolean, optional
+            Indicates if the image dimensions are channel first, or channel last.
+            Inferred from the input shape if None.
+        explain_func: callable
+            Callable generating attributions.
+        explain_func_kwargs: dict, optional
+            Keyword arguments to be passed to explain_func on call.
+        model_predict_kwargs: dict, optional
+            Keyword arguments to be passed to the model's predict method.
+        softmax: boolean
+            Indicates whether to use softmax probabilities or logits in model prediction.
+            This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used.
+        device: string
+            Indicates the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu".
+        kwargs: optional
+            Keyword arguments.
+
+        Returns
+        -------
+        last_results: list
+            a list of Any with the evaluation scores of the concerned batch.
+
+        Examples:
+        --------
+            # Minimal imports.
+            >> import quantus
+            >> from quantus import LeNet
+            >> import torch
+
+            # Enable GPU.
+            >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+            # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
+            >> model = LeNet()
+            >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
+
+            # Load MNIST datasets and make loaders.
+            >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
+            >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
+
+            # Load a batch of inputs and outputs to use for XAI evaluation.
+            >> x_batch, y_batch = next(iter(test_loader))
+            >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
+
+            # Generate Saliency attributions for the test set batch.
+            >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
+            >> a_batch_saliency = a_batch_saliency.cpu().numpy()
+
+            # Initialise the metric and evaluate explanations by calling the metric instance.
+            >> metric = Sensitivity(abs=True, normalise=False)
+            >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
+        """
+        return super().__call__(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch,
+            s_batch=s_batch,
+            custom_batch=None,
+            channel_first=channel_first,
+            explain_func=explain_func,
+            explain_func_kwargs=explain_func_kwargs,
+            softmax=softmax,
+            device=device,
+            model_predict_kwargs=model_predict_kwargs,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    def evaluate_batch(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: np.ndarray,
+        s_batch: np.ndarray,
+        aggregate_result: bool = False
+    ) -> np.ndarray:
+        """
+        Evaluates model and attributes on a single data batch and returns the batched evaluation result.
+
+        Parameters
+        ----------
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x_batch: np.ndarray
+            The input to be evaluated on an instance-basis.
+        y_batch: np.ndarray
+            The output to be evaluated on an instance-basis.
+        a_batch: np.ndarray
+            The explanation to be evaluated on an instance-basis.
+        s_batch: np.ndarray
+            The segmentation to be evaluated on an instance-basis.
+
+        Returns
+        -------
+        : np.ndarray
+            The batched evaluation results.
+        """
+        batch_size = x_batch.shape[0]
+        similarities = np.zeros((batch_size, self.nr_samples)) * np.nan
+
+        for step_id in range(self.nr_samples):
+
+            # Perturb input.
+            x_perturbed = perturb_batch(
+                perturb_func=self.perturb_func,
+                indices=np.tile(np.arange(0, x_batch[0].size), (batch_size, 1)),
+                indexed_axes=np.arange(0, x_batch[0].ndim),
+                arr=x_batch,
+                **self.perturb_func_kwargs,
+            )
+
+            changed_prediction_indices = (
+                np.argwhere(
+                    model.predict(x_batch).argmax(axis=-1)
+                    != model.predict(x_perturbed).argmax(axis=-1)
+                ).reshape(-1)
+                if self.return_nan_when_prediction_changes
+                else []
+            )
+
+            x_input = model.shape_input(
+                x=x_perturbed,
+                shape=x_batch.shape,
+                channel_first=True,
+                batched=True,
+            )
+
+            # Generate explanation based on perturbed input x.
+            a_perturbed = self.explain_func(
+                model=model.get_model(),
+                inputs=x_input,
+                targets=y_batch,
+                **self.explain_func_kwargs,
+            )
+
+            if self.normalise:
+                a_perturbed = self.normalise_func(
+                    a_perturbed,
+                    **self.normalise_func_kwargs,
+                )
+
+            if self.abs:
+                a_perturbed = np.abs(a_perturbed)
+
+            # Measure similarity for each instance separately.
+            for instance_id in range(batch_size):
+
+                if (
+                    self.return_nan_when_prediction_changes
+                    and instance_id in changed_prediction_indices
+                ):
+                    similarities[instance_id, step_id] = np.nan
+                    continue
+
+                sensitivities = self.similarity_func(
+                    a=a_batch[instance_id].flatten(),
+                    b=a_perturbed[instance_id].flatten(),
+                )
+                numerator = self.norm_numerator(a=sensitivities)
+                sensitivities_norm = numerator
+                similarities[instance_id, step_id] = sensitivities_norm
+
+        if aggregate_result:
+            mean_func = np.mean if self.return_nan_when_prediction_changes else np.nanmean
+            return mean_func(similarities, axis=1)
+        return similarities.tolist()
\ No newline at end of file
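
Review note: the sketch below is a minimal, self-contained illustration of what the new word-level evaluate()/__call__ pair in metrics.py computes: mask one token at a time, measure the drop in the model score, then keep an explanation only if its relevances correlate significantly with those drops. It is not the package's API; model_score, the toy token weights and the imputation id are hypothetical stand-ins, and pearsonr(..., alternative=...) requires SciPy >= 1.9.

import numpy as np
from scipy.stats import pearsonr, spearmanr

# Hypothetical stand-in for the wrapped classifier: scores a sequence of token ids.
rng = np.random.default_rng(0)
token_weights = rng.random(30)

def model_score(token_ids):
    return token_weights[token_ids].sum()

sentence = rng.integers(0, 30, size=12)   # "tokenised" input sentence
relevances = token_weights[sentence]      # stand-in for an explainer's relevance scores
impute_value = 0                          # id of the padding/imputation token

# Leave-one-out masking, as in evaluate(): replace each token, re-score, restore.
init_score = model_score(sentence)
score_drops = np.empty(len(sentence))
for i in range(len(sentence)):
    tmp = sentence[i]
    sentence[i] = impute_value
    score_drops[i] = init_score - model_score(sentence)
    sentence[i] = tmp

# Correlation gate, as in __call__(): keep the explanation only if both one-sided
# tests report a significantly positive association between relevance and score drop.
p_thresh = 0.05
pearson = pearsonr(relevances, score_drops, alternative='greater')
spearman = spearmanr(relevances, score_drops, alternative='greater')
keep = pearson[1] <= p_thresh and spearman[1] <= p_thresh
print(keep, round(pearson[0], 3), round(spearman[0], 3))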
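Review note: Sensitivity.evaluate_batch() in sensitivity.py follows the Avg-Sensitivity recipe: draw bounded uniform noise, re-explain the perturbed input, take the Frobenius norm of the attribution difference, and average over nr_samples draws (only the numerator of Yeh et al.'s ratio is kept here). The NumPy-only sketch below illustrates that estimate outside of Quantus; explain_fn and the toy linear explainer are hypothetical stand-ins for explain_func and the model, not part of Quantus or dianna.

import numpy as np

def avg_sensitivity(x, explain_fn, nr_samples=200, lower_bound=0.2, seed=0):
    """Monte Carlo estimate of average sensitivity for a single input `x`."""
    rng = np.random.default_rng(seed)
    a_ref = explain_fn(x).flatten()
    norms = np.empty(nr_samples)
    for s in range(nr_samples):
        noise = rng.uniform(-lower_bound, lower_bound, size=x.shape)
        a_pert = explain_fn(x + noise).flatten()
        norms[s] = np.linalg.norm(a_pert - a_ref)  # fro_norm of the attribution difference
    return norms.mean()

# Toy explainer: the "attribution" of a linear scorer w.x is w scaled by the input sign,
# purely for demonstration.
w = np.array([0.5, -1.0, 2.0])
explain_fn = lambda x: w * np.sign(x)
print(avg_sensitivity(np.array([1.0, -2.0, 0.5]), explain_fn, nr_samples=50))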