diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f2d6464..d3cee218 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,13 +18,14 @@
 * Refactor OpenVINO imports by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/45
 * Support OV IR / ONNX model file for Explainer by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/47
 * Try CNN -> ViT assumption for IR insertion by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/48
-* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49
+* Enable AISE for classification: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49
 * Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52
 * Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
 * Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
 * Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
 * Add [Insertion-Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
 * Add [ADCC](https://arxiv.org/abs/2104.10252) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/57
+* Enable AISE for detection: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/55

 ### Known Issues

diff --git a/README.md b/README.md
index db0b9fd3..2805da9d 100644
--- a/README.md
+++ b/README.md
@@ -71,10 +71,11 @@ At the moment, *Image Classification* and *Object Detection* tasks are supported
 |-----------------|----------------------|-----------|---------------------|-------|
 | Computer Vision | Image Classification | White-Box | ReciproCAM | [arxiv](https://arxiv.org/abs/2209.14074) / [src](openvino_xai/methods/white_box/recipro_cam.py) |
 | | | | VITReciproCAM | [arxiv](https://arxiv.org/abs/2310.02588) / [src](openvino_xai/methods/white_box/recipro_cam.py) |
-| | | | ActivationMap | experimental / [src](openvino_xai/methods/white_box/activation_map.py) |
-| | | Black-Box | AISE | [src](openvino_xai/methods/black_box/aise.py) |
-| | | | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) |
-| | Object Detection | White-Box | ClassProbabilityMap | experimental / [src](openvino_xai/methods/white_box/det_class_probability_map.py) |
+| | | | ActivationMap | experimental / [src](openvino_xai/methods/white_box/activation_map.py) |
+| | | Black-Box | AISEClassification | [src](openvino_xai/methods/black_box/aise.py) |
+| | | | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) |
+| | Object Detection | White-Box | ClassProbabilityMap | experimental / [src](openvino_xai/methods/white_box/det_class_probability_map.py) |
+| | | Black-Box | AISEDetection | [src](openvino_xai/methods/black_box/aise.py) |

 ### Supported explainable models

diff --git a/docs/source/user-guide.md b/docs/source/user-guide.md
index 8916e046..284cf7aa 100644
--- a/docs/source/user-guide.md
+++ b/docs/source/user-guide.md
@@ -252,7 +252,7 @@ explanation.save("output_path", "name_")

 Black-box mode does not update the model (treating the model as a black box).
 Black-box approaches are based on the perturbation of the input data and measurement of the model's output change.

-For black-box mode we support 2 algorithms: **AISE** (by default) and [**RISE**](https://arxiv.org/abs/1806.07421). AISE is more effective for generating saliency maps for a few specific classes. RISE - to generate maps for all classes at once.
+For black-box mode we support two algorithms: **AISE** (the default) and [**RISE**](https://arxiv.org/abs/1806.07421). AISE is more effective for generating saliency maps for a few specific classes, while RISE generates maps for all classes at once. AISE is supported for both classification and detection tasks.

 Pros:
 - **Flexible** - can be applied to any custom model.
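To make the user-guide claim concrete, below is a minimal sketch of driving the classification variant of AISE directly, mirroring the unit tests later in this patch. The model path, the 224x224 input size, and the score-extraction line are illustrative assumptions, not part of the library contract.

```python
# Hedged sketch: direct use of AISEClassification (paths and sizes are placeholders).
import cv2
import numpy as np
import openvino as ov

from openvino_xai.methods.black_box.aise.classification import AISEClassification
from openvino_xai.methods.black_box.base import Preset


def preprocess_fn(x: np.ndarray) -> np.ndarray:
    x = cv2.resize(src=x, dsize=(224, 224))  # assumed model input size
    x = x.transpose((2, 0, 1))  # HWC -> CHW
    return np.expand_dims(x, 0)  # add batch dim


def postprocess_fn(x) -> np.ndarray:
    return x[0]  # assumption: scores live in the first model output


model = ov.Core().read_model("classifier.xml")  # placeholder path
method = AISEClassification(model, postprocess_fn, preprocess_fn)
saliency_maps = method.generate_saliency_map(
    data=cv2.imread("image.jpg"),  # placeholder path
    target_indices=[11, 14],  # a few specific classes - AISE's strong suit
    preset=Preset.BALANCE,
)
# saliency_maps maps each class index to a 2D uint8 saliency map.
```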
diff --git a/examples/run_detection.py b/examples/run_detection.py
index da2533dc..f123e2cc 100644
--- a/examples/run_detection.py
+++ b/examples/run_detection.py
@@ -12,6 +12,7 @@
 import openvino_xai as xai
 from openvino_xai.common.utils import logger
 from openvino_xai.explainer.explainer import ExplainMode
+from openvino_xai.methods.black_box.base import Preset


 def get_argument_parser():
@@ -31,20 +32,22 @@ def preprocess_fn(x: np.ndarray) -> np.ndarray:
     return x


-def main(argv):
+def postprocess_fn(x) -> np.ndarray:
+    """Returns boxes, scores, labels."""
+    return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
+
+
+def explain_white_box(args):
     """
     White-box scenario.
-    Insertion of the XAI branch into the Model API wrapper, thus Model API wrapper has additional 'saliency_map' output.
+    Insertion of the XAI branch into the model, so the model has an additional 'saliency_map' output.
     """

-    parser = get_argument_parser()
-    args = parser.parse_args(argv)
-
     # Create ov.Model
     model: ov.Model
     model = ov.Core().read_model(args.model_path)

-    # OTX YOLOX
+    # # OTX YOLOX
     # cls_head_output_node_names = [
     #     "/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
     #     "/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
@@ -75,6 +78,7 @@ def main(argv):
     explanation = explainer(
         image,
         targets=[0, 1, 2],  # target classes to explain
+        overlay=True,
     )

     logger.info(
@@ -88,5 +92,53 @@ def main(argv):
     explanation.save(output, Path(args.image_path).stem)


+def explain_black_box(args):
+    """
+    Black-box scenario.
+    """
+
+    # Create ov.Model
+    model: ov.Model
+    model = ov.Core().read_model(args.model_path)
+
+    # Create explainer object
+    explainer = xai.Explainer(
+        model=model,
+        task=xai.Task.DETECTION,
+        preprocess_fn=preprocess_fn,
+        postprocess_fn=postprocess_fn,
+        explain_mode=ExplainMode.BLACKBOX,  # defaults to AUTO
+    )
+
+    # Prepare input image and explanation parameters, can be different for each explain call
+    image = cv2.imread(args.image_path)
+
+    # Generate explanation
+    explanation = explainer(
+        image,
+        targets=[0],  # target boxes to explain
+        overlay=True,
+        preset=Preset.SPEED,
+    )
+
+    logger.info(
+        f"Generated {len(explanation.saliency_map)} detection "
+        f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
+    )
+
+    # Save saliency maps for visual inspection
+    if args.output is not None:
+        output = Path(args.output) / "detection_black_box"
+        explanation.save(output, f"{Path(args.image_path).stem}_")
+
+
+def main(argv):
+    parser = get_argument_parser()
+    args = parser.parse_args(argv)
+
+    explain_white_box(args)
+    explain_black_box(args)
+
+
 if __name__ == "__main__":
     main(sys.argv[1:])
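The postprocess_fn contract above is the part most likely to trip up adopters, so here is what the slicing unpacks, assuming the OTX-style detection output where `x["boxes"]` has shape (1, N, 5) with [x1, y1, x2, y2, score] rows and `x["labels"]` has shape (1, N); the synthetic values are illustrative only.

```python
# Toy check of the detection postprocess_fn contract (synthetic data, shapes assumed).
import numpy as np

x = {
    "boxes": np.array([[[10.0, 20.0, 50.0, 80.0, 0.91],
                        [30.0, 40.0, 90.0, 120.0, 0.57]]]),  # shape (1, N, 5)
    "labels": np.array([[2, 0]]),  # shape (1, N)
}

boxes, scores, labels = x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
assert boxes.shape == (2, 4)  # [x1, y1, x2, y2] per detection, batch dim stripped
assert scores.tolist() == [0.91, 0.57]
assert labels.tolist() == [2, 0]
```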
""" def __call__( @@ -130,7 +131,7 @@ def visualize( original_input_image = format_to_bhwc(original_input_image) saliency_map_dict = explanation.saliency_map - class_idx_to_return = list(saliency_map_dict.keys()) + indices_to_return = list(saliency_map_dict.keys()) # Convert to numpy array to use vectorized scale (0 ~ 255) operation and speed up lots of classes scenario saliency_map_np = np.array(list(saliency_map_dict.values())) @@ -146,6 +147,7 @@ def visualize( saliency_map_np = self._apply_overlay( explanation, saliency_map_np, original_input_image, output_size, overlay_weight ) + saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return) else: if resize: if original_input_image is None and output_size is None: @@ -157,7 +159,30 @@ def visualize( saliency_map_np = self._apply_colormap(explanation, saliency_map_np) # Convert back to dict - return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, class_idx_to_return) + return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return) + + @staticmethod + def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]): + # TODO (negvet): support when indices are strings + if metadata: + if Task.DETECTION in metadata: + for smap_i, target_index in zip(range(len(saliency_map_np)), indices): + saliency_map = saliency_map_np[smap_i] + box, score, label_index = metadata[Task.DETECTION][target_index] + x1, y1, x2, y2 = box + cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2) + box_label = f"{label_index}|{score:.2f}" + box_label_loc = int(x1), int(y1 - 5) + cv2.putText( + saliency_map, + box_label, + org=box_label_loc, + fontFace=1, + fontScale=1, + color=(255, 0, 0), + thickness=2, + ) + return saliency_map_np @staticmethod def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray: @@ -222,15 +247,15 @@ def _apply_overlay( def _update_explanation_with_processed_sal_map( explanation: Explanation, saliency_map_np: np.ndarray, - class_idx: List, + target_indices: List, ) -> Explanation: dict_sal_map: Dict[int | str, np.ndarray] = {} if explanation.layout in ONE_MAP_LAYOUTS: dict_sal_map["per_image_map"] = saliency_map_np[0] saliency_map_np = dict_sal_map elif explanation.layout in MULTIPLE_MAP_LAYOUTS: - for idx, class_sal in zip(class_idx, saliency_map_np): - dict_sal_map[idx] = class_sal + for index, sal_map in zip(target_indices, saliency_map_np): + dict_sal_map[index] = sal_map else: raise ValueError explanation.saliency_map = dict_sal_map diff --git a/openvino_xai/methods/__init__.py b/openvino_xai/methods/__init__.py index 58c3a114..08cce51a 100644 --- a/openvino_xai/methods/__init__.py +++ b/openvino_xai/methods/__init__.py @@ -3,7 +3,8 @@ """ XAI algorithms. 
""" -from openvino_xai.methods.black_box.aise import AISE +from openvino_xai.methods.black_box.aise.classification import AISEClassification +from openvino_xai.methods.black_box.aise.detection import AISEDetection from openvino_xai.methods.black_box.rise import RISE from openvino_xai.methods.white_box.activation_map import ActivationMap from openvino_xai.methods.white_box.base import WhiteBoxMethod @@ -24,5 +25,6 @@ "ViTReciproCAM", "DetClassProbabilityMap", "RISE", - "AISE", + "AISEClassification", + "AISEDetection", ] diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py index 2673491d..59d207db 100644 --- a/openvino_xai/methods/base.py +++ b/openvino_xai/methods/base.py @@ -1,12 +1,14 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import collections from abc import ABC, abstractmethod -from typing import Callable, Dict, Mapping +from typing import Any, Callable, Dict, Mapping import numpy as np import openvino as ov +from openvino_xai.common.parameters import Task from openvino_xai.common.utils import IdentityPreprocessFN @@ -23,6 +25,7 @@ def __init__( self._model_compiled = None self.preprocess_fn = preprocess_fn self._device_name = device_name + self.metadata: Dict[Task, Any] = collections.defaultdict(dict) @property def model_compiled(self) -> ov.CompiledModel | None: diff --git a/openvino_xai/methods/black_box/aise/__init__.py b/openvino_xai/methods/black_box/aise/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/openvino_xai/methods/black_box/aise.py b/openvino_xai/methods/black_box/aise/base.py similarity index 50% rename from openvino_xai/methods/black_box/aise.py rename to openvino_xai/methods/black_box/aise/base.py index b2e27324..e1c98289 100644 --- a/openvino_xai/methods/black_box/aise.py +++ b/openvino_xai/methods/black_box/aise/base.py @@ -3,24 +3,20 @@ import collections import math +from abc import ABC, abstractmethod from typing import Callable, Dict, List, Mapping, Tuple import numpy as np import openvino.runtime as ov -from scipy.optimize import Bounds, direct +from scipy.optimize import direct -from openvino_xai.common.utils import ( - IdentityPreprocessFN, - infer_size_from_image, - logger, - scaling, - sigmoid, -) -from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset +from openvino_xai.common.utils import IdentityPreprocessFN +from openvino_xai.methods.black_box.base import BlackBoxXAIMethod -class AISE(BlackBoxXAIMethod): - """AISE explains classification models in black-box mode using +class AISEBase(BlackBoxXAIMethod, ABC): + """ + AISE explains models in black-box mode using AISE: Adaptive Input Sampling for Explanation of Black-box Models (TODO (negvet): add link to the paper.) @@ -59,98 +55,13 @@ def __init__( self.pred_score_hist: Dict = collections.defaultdict(list) self.input_size: Tuple[int, int] | None = None self._mask_generator: GaussianPerturbationMask | None = None + self.bounds = None + self.preservation = True + self.deletion = True if prepare_model: self.prepare_model() - def generate_saliency_map( # type: ignore - self, - data: np.ndarray, - target_indices: List[int] | None, - preset: Preset = Preset.BALANCE, - num_iterations_per_kernel: int | None = None, - kernel_widths: List[float] | np.ndarray | None = None, - solver_epsilon: float = 0.1, - locally_biased: bool = False, - scale_output: bool = True, - ) -> Dict[int, np.ndarray]: - """ - Generates inference result of the AISE algorithm. - Optimized for per class saliency map generation. 
diff --git a/openvino_xai/methods/black_box/aise/__init__.py b/openvino_xai/methods/black_box/aise/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/openvino_xai/methods/black_box/aise.py b/openvino_xai/methods/black_box/aise/base.py
similarity index 50%
rename from openvino_xai/methods/black_box/aise.py
rename to openvino_xai/methods/black_box/aise/base.py
index b2e27324..e1c98289 100644
--- a/openvino_xai/methods/black_box/aise.py
+++ b/openvino_xai/methods/black_box/aise/base.py
@@ -3,24 +3,20 @@

 import collections
 import math
+from abc import ABC, abstractmethod
 from typing import Callable, Dict, List, Mapping, Tuple

 import numpy as np
 import openvino.runtime as ov
-from scipy.optimize import Bounds, direct
+from scipy.optimize import direct

-from openvino_xai.common.utils import (
-    IdentityPreprocessFN,
-    infer_size_from_image,
-    logger,
-    scaling,
-    sigmoid,
-)
-from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset
+from openvino_xai.common.utils import IdentityPreprocessFN
+from openvino_xai.methods.black_box.base import BlackBoxXAIMethod


-class AISE(BlackBoxXAIMethod):
-    """AISE explains classification models in black-box mode using
+class AISEBase(BlackBoxXAIMethod, ABC):
+    """
+    AISE explains models in black-box mode using
     AISE: Adaptive Input Sampling for Explanation of Black-box Models
     (TODO (negvet): add link to the paper.)

@@ -59,98 +55,13 @@ def __init__(
         self.pred_score_hist: Dict = collections.defaultdict(list)
         self.input_size: Tuple[int, int] | None = None
         self._mask_generator: GaussianPerturbationMask | None = None
+        self.bounds = None
+        self.preservation = True
+        self.deletion = True

         if prepare_model:
             self.prepare_model()

-    def generate_saliency_map(  # type: ignore
-        self,
-        data: np.ndarray,
-        target_indices: List[int] | None,
-        preset: Preset = Preset.BALANCE,
-        num_iterations_per_kernel: int | None = None,
-        kernel_widths: List[float] | np.ndarray | None = None,
-        solver_epsilon: float = 0.1,
-        locally_biased: bool = False,
-        scale_output: bool = True,
-    ) -> Dict[int, np.ndarray]:
-        """
-        Generates inference result of the AISE algorithm.
-        Optimized for per class saliency map generation. Not effcient for large number of classes.
-
-        :param data: Input image.
-        :type data: np.ndarray
-        :param target_indices: List of target indices to explain.
-        :type target_indices: List[int]
-        :param preset: Speed-Quality preset, defines predefined configurations that manage the speed-quality tradeoff.
-        :type preset: Preset
-        :param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget.
-        :type num_iterations_per_kernel: int
-        :param kernel_widths: Kernel bandwidths.
-        :type kernel_widths: List[float] | np.ndarray
-        :param solver_epsilon: Solver epsilon of DIRECT optimizer.
-        :type solver_epsilon: float
-        :param locally_biased: Locally biased flag of DIRECT optimizer.
-        :type locally_biased: bool
-        :param scale_output: Whether to scale output or not.
-        :type scale_output: bool
-        """
-        self.data_preprocessed = self.preprocess_fn(data)
-
-        if target_indices is None:
-            num_classes = self.get_num_classes(self.data_preprocessed)
-            if num_classes > 10:
-                logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
-            target_indices = list(range(num_classes))
-
-        self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
-            preset,
-            num_iterations_per_kernel,
-            kernel_widths,
-        )
-
-        self.solver_epsilon = solver_epsilon
-        self.locally_biased = locally_biased
-
-        self.input_size = infer_size_from_image(self.data_preprocessed)
-        self._mask_generator = GaussianPerturbationMask(self.input_size)
-
-        saliency_maps = {}
-        for target in target_indices:
-            self.kernel_params_hist = collections.defaultdict(list)
-            self.pred_score_hist = collections.defaultdict(list)
-
-            self.target = target
-            saliency_map_per_target = self._run_synchronous_explanation()
-            if scale_output:
-                saliency_map_per_target = scaling(saliency_map_per_target)
-            saliency_maps[target] = saliency_map_per_target
-        return saliency_maps
-
-    @staticmethod
-    def _preset_parameters(
-        preset: Preset,
-        num_iterations_per_kernel: int | None,
-        kernel_widths: List[float] | np.ndarray | None,
-    ) -> Tuple[int, np.ndarray]:
-        if preset == Preset.SPEED:
-            iterations = 25
-            widths = np.linspace(0.1, 0.25, 3)
-        elif preset == Preset.BALANCE:
-            iterations = 50
-            widths = np.linspace(0.1, 0.25, 3)
-        elif preset == Preset.QUALITY:
-            iterations = 85
-            widths = np.linspace(0.075, 0.25, 4)
-        else:
-            raise ValueError(f"Preset {preset} is not supported.")
-
-        if num_iterations_per_kernel is None:
-            num_iterations_per_kernel = iterations
-        if kernel_widths is None:
-            kernel_widths = widths
-        return num_iterations_per_kernel, kernel_widths
-
     def _run_synchronous_explanation(self) -> np.ndarray:
         for kernel_width in self.kernel_widths:
             self._current_kernel_width = kernel_width
@@ -161,7 +72,7 @@ def _run_optimization(self):
         """Run DIRECT optimizer by default."""
         _ = direct(
             func=self._objective_function,
-            bounds=Bounds([0.0, 0.0], [1.0, 1.0]),
+            bounds=self.bounds,
             eps=self.solver_epsilon,
             maxfun=self.num_iterations_per_kernel,
             locally_biased=self.locally_biased,
         )
@@ -170,7 +81,7 @@ def _objective_function(self, args) -> float:
         """
         Objective function to optimize (to find a global minimum).
-        Hybrid (dual) paradigm adopted with two sub-objectives:
+        Hybrid (dual) paradigm supports two sub-objectives:
             - preservation
             - deletion
         """
@@ -181,27 +92,26 @@ def _objective_function(self, args) -> float:
         kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params)
         kernel_mask = np.clip(kernel_mask, 0, 1)

-        data_perturbed_preserve = self.data_preprocessed * kernel_mask
-        pred_score_preserve = self._get_score(data_perturbed_preserve)
+        pred_loss_preserve = 0.0
+        if self.preservation:
+            data_perturbed_preserve = self.data_preprocessed * kernel_mask
+            pred_loss_preserve = self._get_loss(data_perturbed_preserve)

-        data_perturbed_delete = self.data_preprocessed * (1 - kernel_mask)
-        pred_score_delete = self._get_score(data_perturbed_delete)
+        pred_loss_delete = 0.0
+        if self.deletion:
+            data_perturbed_delete = self.data_preprocessed * (1 - kernel_mask)
+            pred_loss_delete = self._get_loss(data_perturbed_delete)

-        loss = pred_score_preserve - pred_score_delete
+        loss = pred_loss_preserve - pred_loss_delete

-        self.pred_score_hist[self._current_kernel_width].append(pred_score_preserve - pred_score_delete)
+        self.pred_score_hist[self._current_kernel_width].append(pred_loss_preserve - pred_loss_delete)

         loss *= -1  # Objective: minimize
         return loss

-    def _get_score(self, data_perturbed: np.array) -> float:
-        """Get model prediction score for perturbed input."""
-        x = self.model_forward(data_perturbed, preprocess=False)
-        x = self.postprocess_fn(x)
-        if np.max(x) > 1 or np.min(x) < 0:
-            x = sigmoid(x)
-        pred_scores = x.squeeze()  # type: ignore
-        return pred_scores[self.target]
+    @abstractmethod
+    def _get_loss(self, data_perturbed: np.array) -> float:
+        pass

     def _kernel_density_estimation(self) -> np.ndarray:
         """Aggregate the result per kernel with KDE."""
@@ -213,7 +123,9 @@ def _kernel_density_estimation(self) -> np.ndarray:
                 kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params)
                 score = self.pred_score_hist[kernel_width][i]
                 kernel_masks_weighted += kernel_mask * score
-            kernel_masks_weighted = kernel_masks_weighted / kernel_masks_weighted.max()
+            kernel_masks_weighted_max = kernel_masks_weighted.max()
+            if kernel_masks_weighted_max > 0:
+                kernel_masks_weighted = kernel_masks_weighted / kernel_masks_weighted_max
             saliency_map_per_kernel[kernel_index] = kernel_masks_weighted

         saliency_map = saliency_map_per_kernel.sum(axis=0)
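For intuition, the preservation/deletion switches introduced here reduce to a small dual objective. The toy restatement below uses a stand-in scalar "model" and is illustrative only, not the library API; it also shows why AISEDetection can simply set deletion to False.

```python
# Toy restatement of AISEBase._objective_function (illustrative sketch).
import numpy as np


def dual_objective(f, x: np.ndarray, mask: np.ndarray, preservation=True, deletion=True) -> float:
    preserve = f(x * mask) if preservation else 0.0  # keep only the masked region: score should stay high
    delete = f(x * (1 - mask)) if deletion else 0.0  # remove the masked region: score should drop
    return -(preserve - delete)  # DIRECT minimizes, so the gap is negated


# Stand-in "model" that just sums intensities over a 4x4 image:
x = np.ones((4, 4))
mask = np.zeros((4, 4))
mask[:2, :2] = 1.0
print(dual_objective(np.sum, x, mask))  # -(4 - 12) = 8.0
```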
diff --git a/openvino_xai/methods/black_box/aise/classification.py b/openvino_xai/methods/black_box/aise/classification.py
new file mode 100644
index 00000000..3796877f
--- /dev/null
+++ b/openvino_xai/methods/black_box/aise/classification.py
@@ -0,0 +1,154 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import collections
+from typing import Callable, Dict, List, Tuple
+
+import numpy as np
+import openvino.runtime as ov
+from openvino.runtime.utils.data_helpers.wrappers import OVDict
+from scipy.optimize import Bounds
+
+from openvino_xai.common.utils import (
+    IdentityPreprocessFN,
+    infer_size_from_image,
+    logger,
+    scaling,
+    sigmoid,
+)
+from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
+from openvino_xai.methods.black_box.base import Preset
+
+
+class AISEClassification(AISEBase):
+    """
+    AISE for classification models.
+
+    postprocess_fn is expected to return one container with scores, without batch dim.
+
+    :param model: OpenVINO model.
+    :type model: ov.Model
+    :param postprocess_fn: Post-processing function that extracts scores from IR model output.
+    :type postprocess_fn: Callable[[OVDict], np.ndarray]
+    :param preprocess_fn: Pre-processing function, identity function by default
+        (assume input images are already preprocessed by user).
+    :type preprocess_fn: Callable[[np.ndarray], np.ndarray]
+    :param device_name: Device type name.
+    :type device_name: str
+    :param prepare_model: Loading (compiling) the model prior to inference.
+    :type prepare_model: bool
+    """
+
+    def __init__(
+        self,
+        model: ov.Model,
+        postprocess_fn: Callable[[OVDict], np.ndarray],
+        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
+        device_name: str = "CPU",
+        prepare_model: bool = True,
+    ):
+        super().__init__(
+            model=model,
+            postprocess_fn=postprocess_fn,
+            preprocess_fn=preprocess_fn,
+            device_name=device_name,
+            prepare_model=prepare_model,
+        )
+        self.bounds = Bounds([0.0, 0.0], [1.0, 1.0])
+
+    def generate_saliency_map(  # type: ignore
+        self,
+        data: np.ndarray,
+        target_indices: List[int] | None,
+        preset: Preset = Preset.BALANCE,
+        num_iterations_per_kernel: int | None = None,
+        kernel_widths: List[float] | np.ndarray | None = None,
+        solver_epsilon: float = 0.1,
+        locally_biased: bool = False,
+        scale_output: bool = True,
+    ) -> Dict[int, np.ndarray]:
+        """
+        Generates inference result of the AISE algorithm.
+        Optimized for per-class saliency map generation. Not efficient for a large number of classes.
+
+        :param data: Input image.
+        :type data: np.ndarray
+        :param target_indices: List of target indices to explain.
+        :type target_indices: List[int]
+        :param preset: Speed-Quality preset, defines predefined configurations that manage the speed-quality tradeoff.
+        :type preset: Preset
+        :param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget.
+        :type num_iterations_per_kernel: int
+        :param kernel_widths: Kernel bandwidths.
+        :type kernel_widths: List[float] | np.ndarray
+        :param solver_epsilon: Solver epsilon of DIRECT optimizer.
+        :type solver_epsilon: float
+        :param locally_biased: Locally biased flag of DIRECT optimizer.
+        :type locally_biased: bool
+        :param scale_output: Whether to scale output or not.
+        :type scale_output: bool
+        """
+        self.data_preprocessed = self.preprocess_fn(data)
+
+        if target_indices is None:
+            num_classes = self.get_num_classes(self.data_preprocessed)
+            if num_classes > 10:
+                logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
+            target_indices = list(range(num_classes))
+
+        self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
+            preset,
+            num_iterations_per_kernel,
+            kernel_widths,
+        )
+
+        self.solver_epsilon = solver_epsilon
+        self.locally_biased = locally_biased
+
+        self.input_size = infer_size_from_image(self.data_preprocessed)
+        self._mask_generator = GaussianPerturbationMask(self.input_size)
+
+        saliency_maps = {}
+        for target in target_indices:
+            self.kernel_params_hist = collections.defaultdict(list)
+            self.pred_score_hist = collections.defaultdict(list)
+
+            self.target = target
+            saliency_map_per_target = self._run_synchronous_explanation()
+            if scale_output:
+                saliency_map_per_target = scaling(saliency_map_per_target)
+            saliency_maps[target] = saliency_map_per_target
+        return saliency_maps
+
+    @staticmethod
+    def _preset_parameters(
+        preset: Preset,
+        num_iterations_per_kernel: int | None,
+        kernel_widths: List[float] | np.ndarray | None,
+    ) -> Tuple[int, np.ndarray]:
+        if preset == Preset.SPEED:
+            iterations = 25
+            widths = np.linspace(0.1, 0.25, 3)
+        elif preset == Preset.BALANCE:
+            iterations = 50
+            widths = np.linspace(0.1, 0.25, 3)
+        elif preset == Preset.QUALITY:
+            iterations = 85
+            widths = np.linspace(0.075, 0.25, 4)
+        else:
+            raise ValueError(f"Preset {preset} is not supported.")
+
+        if num_iterations_per_kernel is None:
+            num_iterations_per_kernel = iterations
+        if kernel_widths is None:
+            kernel_widths = widths
+        return num_iterations_per_kernel, kernel_widths
+
+    def _get_loss(self, data_perturbed: np.array) -> float:
+        """Get loss for perturbed input."""
+        x = self.model_forward(data_perturbed, preprocess=False)
+        x = self.postprocess_fn(x)
+        if np.max(x) > 1 or np.min(x) < 0:
+            x = sigmoid(x)
+        pred_scores = x.squeeze()  # type: ignore
+        return pred_scores[self.target]
diff --git a/openvino_xai/methods/black_box/aise/detection.py b/openvino_xai/methods/black_box/aise/detection.py
new file mode 100644
index 00000000..bac7c3f4
--- /dev/null
+++ b/openvino_xai/methods/black_box/aise/detection.py
@@ -0,0 +1,216 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import collections
+from typing import Any, Callable, Dict, List, Tuple
+
+import numpy as np
+import openvino.runtime as ov
+from openvino.runtime.utils.data_helpers.wrappers import OVDict
+from scipy.optimize import Bounds
+
+from openvino_xai.common.parameters import Task
+from openvino_xai.common.utils import (
+    IdentityPreprocessFN,
+    infer_size_from_image,
+    logger,
+    scaling,
+)
+from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
+from openvino_xai.methods.black_box.base import Preset
+
+
+class AISEDetection(AISEBase):
+    """
+    AISE for detection models.
+
+    postprocess_fn is expected to return three containers: boxes (format: [x1, y1, x2, y2]), scores, labels, without batch dim.
+
+    :param model: OpenVINO model.
+    :type model: ov.Model
+    :param postprocess_fn: Post-processing function that extracts scores from IR model output.
+    :type postprocess_fn: Callable[[OVDict], np.ndarray]
+    :param preprocess_fn: Pre-processing function, identity function by default
+        (assume input images are already preprocessed by user).
+    :type preprocess_fn: Callable[[np.ndarray], np.ndarray]
+    :param device_name: Device type name.
+    :type device_name: str
+    :param prepare_model: Loading (compiling) the model prior to inference.
+    :type prepare_model: bool
+    """
+
+    def __init__(
+        self,
+        model: ov.Model,
+        postprocess_fn: Callable[[OVDict], np.ndarray],
+        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
+        device_name: str = "CPU",
+        prepare_model: bool = True,
+    ):
+        super().__init__(
+            model=model,
+            postprocess_fn=postprocess_fn,
+            preprocess_fn=preprocess_fn,
+            device_name=device_name,
+            prepare_model=prepare_model,
+        )
+        self.deletion = False
+
+    def generate_saliency_map(  # type: ignore
+        self,
+        data: np.ndarray,
+        target_indices: List[int] | None,
+        preset: Preset = Preset.BALANCE,
+        num_iterations_per_kernel: int | None = None,
+        divisors: List[float] | np.ndarray | None = None,
+        solver_epsilon: float = 0.05,
+        locally_biased: bool = False,
+        scale_output: bool = True,
+    ) -> Dict[int, np.ndarray]:
+        """
+        Generates inference result of the AISE algorithm.
+        Optimized for per-box saliency map generation. Not efficient for a large number of boxes.
+
+        :param data: Input image.
+        :type data: np.ndarray
+        :param target_indices: List of target indices to explain.
+        :type target_indices: List[int]
+        :param preset: Speed-Quality preset, defines predefined configurations that manage the speed-quality tradeoff.
+        :type preset: Preset
+        :param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget.
+        :type num_iterations_per_kernel: int
+        :param divisors: List of divisors, used to derive kernel widths in an adaptive manner.
+        :type divisors: List[float] | np.ndarray
+        :param solver_epsilon: Solver epsilon of DIRECT optimizer.
+        :type solver_epsilon: float
+        :param locally_biased: Locally biased flag of DIRECT optimizer.
+        :type locally_biased: bool
+        :param scale_output: Whether to scale output or not.
+        :type scale_output: bool
+        """
+        # TODO (negvet): support custom bboxes (not predicted ones)
+
+        self.data_preprocessed = self.preprocess_fn(data)
+        forward_output = self.model_forward(self.data_preprocessed, preprocess=False)
+
+        # postprocess_fn is expected to return three containers: boxes (x1, y1, x2, y2), scores, labels, without batch dim.
+        boxes, scores, labels = self.postprocess_fn(forward_output)
+
+        if target_indices is None:
+            num_boxes = len(boxes)
+            if num_boxes > 10:
+                logger.info(f"num_boxes = {num_boxes}, which might take significant time to process.")
+            target_indices = list(range(num_boxes))
+
+        self.num_iterations_per_kernel, self.divisors = self._preset_parameters(
+            preset,
+            num_iterations_per_kernel,
+            divisors,
+        )
+
+        self.solver_epsilon = solver_epsilon
+        self.locally_biased = locally_biased
+
+        self.input_size = infer_size_from_image(self.data_preprocessed)
+        original_size = infer_size_from_image(data)
+        self._mask_generator = GaussianPerturbationMask(self.input_size)
+
+        saliency_maps = {}
+        self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
+        for target in target_indices:
+            self.target_box = boxes[target]
+            self.target_label = labels[target]
+
+            if self.target_box[0] >= self.target_box[2] or self.target_box[1] >= self.target_box[3]:
+                continue
+
+            self.kernel_params_hist = collections.defaultdict(list)
+            self.pred_score_hist = collections.defaultdict(list)
+
+            self._process_box()
+            saliency_map_per_target = self._run_synchronous_explanation()
+            if scale_output:
+                saliency_map_per_target = scaling(saliency_map_per_target)
+            saliency_maps[target] = saliency_map_per_target
+
+            self._update_metadata(boxes, scores, labels, target, original_size)
+        return saliency_maps
+
+    @staticmethod
+    def _preset_parameters(
+        preset: Preset,
+        num_iterations_per_kernel: int | None,
+        divisors: List[float] | np.ndarray | None,
+    ) -> Tuple[int, np.ndarray]:
+        if preset == Preset.SPEED:
+            iterations = 50
+            divs = np.linspace(7, 1, 3)
+        elif preset == Preset.BALANCE:
+            iterations = 100
+            divs = np.linspace(7, 1, 3)
+        elif preset == Preset.QUALITY:
+            iterations = 150
+            divs = np.linspace(8, 1, 5)
+        else:
+            raise ValueError(f"Preset {preset} is not supported.")
+
+        if num_iterations_per_kernel is None:
+            num_iterations_per_kernel = iterations
+        if divisors is None:
+            divisors = divs
+        return num_iterations_per_kernel, divisors
+
+    def _process_box(self, padding_coef: float = 0.5) -> None:
+        target_box_scaled = [
+            self.target_box[0] / self.input_size[1],  # x1
+            self.target_box[1] / self.input_size[0],  # y1
+            self.target_box[2] / self.input_size[1],  # x2
+            self.target_box[3] / self.input_size[0],  # y2
+        ]
+        box_width = target_box_scaled[2] - target_box_scaled[0]
+        box_height = target_box_scaled[3] - target_box_scaled[1]
+        self._min_box_size = min(box_width, box_height)
+        self.kernel_widths = [self._min_box_size / div for div in self.divisors]
+
+        x_from = max(target_box_scaled[0] - box_width * padding_coef, 0.0)
+        x_to = min(target_box_scaled[2] + box_width * padding_coef, 1.0)
+        y_from = max(target_box_scaled[1] - box_height * padding_coef, 0.0)
+        y_to = min(target_box_scaled[3] + box_height * padding_coef, 1.0)
+        self.bounds = Bounds([x_from, y_from], [x_to, y_to])
+
+    def _get_loss(self, data_perturbed: np.array) -> float:
+        """Get loss for perturbed input."""
+        forward_output = self.model_forward(data_perturbed, preprocess=False)
+        boxes, pred_scores, labels = self.postprocess_fn(forward_output)
+
+        loss = 0
+        for box, pred_score, label in zip(boxes, pred_scores, labels):
+            if label == self.target_label:
+                loss = max(loss, self._iou(self.target_box, box) * pred_score)
+        return loss
+
+    @staticmethod
+    def _iou(box1: np.ndarray | List[float], box2: np.ndarray | List[float]) -> float:
+        box1 = np.asarray(box1)
+        box2 = np.asarray(box2)
+        tl = np.vstack([box1[:2], box2[:2]]).max(axis=0)
+        br = np.vstack([box1[2:], box2[2:]]).min(axis=0)
+        intersection = np.prod(br - tl) * np.all(tl < br).astype(float)
+        area1 = np.prod(box1[2:] - box1[:2])
+        area2 = np.prod(box2[2:] - box2[:2])
+        return intersection / (area1 + area2 - intersection)
+
+    def _update_metadata(
+        self,
+        boxes: np.ndarray | List,
+        scores: np.ndarray | List[float],
+        labels: np.ndarray | List[int],
+        target: int,
+        original_size: Tuple[int, int],
+    ) -> None:
+        x1, y1, x2, y2 = boxes[target]
+        width_scale = original_size[1] / self.input_size[1]
+        height_scale = original_size[0] / self.input_size[0]
+        x1, x2 = x1 * width_scale, x2 * width_scale
+        y1, y2 = y1 * height_scale, y2 * height_scale
+        self.metadata[Task.DETECTION][target] = [x1, y1, x2, y2], scores[target], labels[target]
diff --git a/openvino_xai/methods/factory.py b/openvino_xai/methods/factory.py
index 07c6680d..c3445ece 100644
--- a/openvino_xai/methods/factory.py
+++ b/openvino_xai/methods/factory.py
@@ -10,7 +10,8 @@
 from openvino_xai.common.parameters import Method, Task
 from openvino_xai.common.utils import IdentityPreprocessFN, logger
 from openvino_xai.methods.base import MethodBase
-from openvino_xai.methods.black_box.aise import AISE
+from openvino_xai.methods.black_box.aise.classification import AISEClassification
+from openvino_xai.methods.black_box.aise.detection import AISEDetection
 from openvino_xai.methods.black_box.base import BlackBoxXAIMethod
 from openvino_xai.methods.black_box.rise import RISE
 from openvino_xai.methods.white_box.activation_map import ActivationMap
@@ -229,11 +230,18 @@ def create_classification_method(
         :type device_name: str
         """
         if explain_method is None or explain_method == Method.AISE:
-            return AISE(model, postprocess_fn, preprocess_fn, device_name, **kwargs)
+            return AISEClassification(model, postprocess_fn, preprocess_fn, device_name, **kwargs)
         elif explain_method == Method.RISE:
             return RISE(model, postprocess_fn, preprocess_fn, device_name, **kwargs)
         raise ValueError(f"Requested explanation method {explain_method} is not implemented.")

     @staticmethod
-    def create_detection_method(*args, **kwargs) -> BlackBoxXAIMethod:
-        raise ValueError("Detection models are not supported in black-box mode yet.")
+    def create_detection_method(
+        model: ov.Model,
+        postprocess_fn: Callable[[Mapping], np.ndarray],
+        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
+        explain_method: Method | None = None,
+        device_name: str = "CPU",
+        **kwargs,
+    ) -> BlackBoxXAIMethod:
+        return AISEDetection(model, postprocess_fn, preprocess_fn, device_name, **kwargs)
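With the factory wired up, callers no longer need to construct the method class by hand. A minimal sketch, assuming create_method forwards its third argument as postprocess_fn (matching create_detection_method's parameter order above) and using a placeholder model path:

```python
# Hedged sketch of the new black-box factory routing for detection.
import numpy as np
import openvino as ov

from openvino_xai.common.parameters import Task
from openvino_xai.methods.black_box.aise.detection import AISEDetection
from openvino_xai.methods.factory import BlackBoxMethodFactory


def postprocess_fn(x) -> np.ndarray:
    """Returns boxes, scores, labels (OTX-style detection output assumed)."""
    return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]


model = ov.Core().read_model("detector.xml")  # placeholder path
method = BlackBoxMethodFactory.create_method(Task.DETECTION, model, postprocess_fn)
assert isinstance(method, AISEDetection)
```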
diff --git a/tests/intg/test_detection.py b/tests/intg/test_detection.py
index 0f3ea397..96e9eb34 100644
--- a/tests/intg/test_detection.py
+++ b/tests/intg/test_detection.py
@@ -14,7 +14,8 @@
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.explainer import Explainer, ExplainMode
 from openvino_xai.explainer.utils import get_preprocess_fn
-from openvino_xai.methods.factory import WhiteBoxMethodFactory
+from openvino_xai.methods.black_box.aise.detection import AISEDetection
+from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory
 from openvino_xai.methods.white_box.det_class_probability_map import (
     DetClassProbabilityMap,
 )
@@ -57,6 +58,7 @@
 MODELS = list(MODEL_CONFIGS.keys())

 DEFAULT_DET_MODEL = "det_mobilenetv2_atss_bccd"
+FAST_DET_MODEL = "det_mobilenetv2_atss_bccd"

 EXPLAIN_ALL_CLASSES = [
     True,
@@ -66,11 +68,11 @@

 class TestDetWB:
     """
-    Tests detection models in WB mode.
+    Tests detection models in white-box mode.
     """

     image = cv2.imread("tests/assets/blood.jpg")

-    _ref_sal_maps_reciprocam = {
+    _ref_sal_maps = {
         "det_mobilenetv2_atss_bccd": np.array([222, 243, 232, 229, 221, 217, 237, 246, 252, 255], dtype=np.uint8),
         "det_mobilenetv2_ssd_bccd": np.array([83, 93, 61, 48, 110, 109, 78, 128, 158, 111], dtype=np.uint8),
         "det_yolox_bccd": np.array([17, 13, 15, 60, 94, 52, 61, 47, 8, 40], dtype=np.uint8),
@@ -120,7 +122,7 @@ def test_detclassprobabilitymap(self, model_name, embed_scaling, explain_all_cla
         assert explanation.saliency_map[0].shape == self._sal_map_size

         actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16)
-        ref_sal_vals = self._ref_sal_maps_reciprocam[model_name].astype(np.uint8)
+        ref_sal_vals = self._ref_sal_maps[model_name].astype(np.uint8)
         if embed_scaling:
             # Reference values generated with embed_scaling=True
             assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
@@ -198,7 +200,7 @@ def test_two_sequential_norms(self):
         )

         actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16)
-        ref_sal_vals = self._ref_sal_maps_reciprocam[DEFAULT_DET_MODEL].astype(np.uint8)
+        ref_sal_vals = self._ref_sal_maps[DEFAULT_DET_MODEL].astype(np.uint8)
         # Reference values generated with embed_scaling=True
         assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

@@ -234,6 +236,108 @@ def get_default_model(self):
         return model


+class TestDetBB:
+    """
+    Tests detection models in black-box mode.
+    """
+
+    image = cv2.imread("tests/assets/blood.jpg")
+
+    @pytest.fixture(autouse=True)
+    def setup(self, fxt_data_root):
+        self.data_dir = fxt_data_root
+
+    @pytest.mark.parametrize("model_name", MODELS)
+    def test_aisedetection(self, model_name):
+        retrieve_otx_model(self.data_dir, model_name)
+        model_path = self.data_dir / "otx_models" / (model_name + ".xml")
+        model = ov.Core().read_model(model_path)
+
+        preprocess_fn = get_preprocess_fn(
+            input_size=MODEL_CONFIGS[model_name].input_size,
+            hwc_to_chw=True,
+        )
+        explainer = Explainer(
+            model=model,
+            task=Task.DETECTION,
+            preprocess_fn=preprocess_fn,
+            postprocess_fn=self.postprocess_fn,
+            explain_mode=ExplainMode.BLACKBOX,  # defaults to AUTO
+            num_iterations_per_kernel=5,
+            divisors=[5],
+        )
+
+        target_list = [1]
+        explanation = explainer(
+            self.image,
+            targets=target_list,
+            resize=False,
+            colormap=False,
+        )
+        assert explanation is not None
+
+        target_class = target_list[0]
+        assert target_class in explanation.saliency_map
+        assert len(explanation.saliency_map) == len(target_list)
+        assert explanation.saliency_map[target_class].ndim == 2
+
+    def test_detection_visualizing(self):
+        model = self.get_default_model()
+
+        preprocess_fn = get_preprocess_fn(
+            input_size=MODEL_CONFIGS[FAST_DET_MODEL].input_size,
+            hwc_to_chw=True,
+        )
+        explainer = Explainer(
+            model=model,
+            task=Task.DETECTION,
+            preprocess_fn=preprocess_fn,
+            postprocess_fn=self.postprocess_fn,
+            explain_mode=ExplainMode.BLACKBOX,  # defaults to AUTO
+            num_iterations_per_kernel=5,
+            divisors=[5],
+        )
+
+        target_list = [1]
+        explanation = explainer(
+            self.image,
+            targets=target_list,
+            overlay=True,
+        )
+        assert explanation is not None
+        assert explanation.shape == (480, 640, 3)
+
+        target_class = target_list[0]
+        assert len(explanation.saliency_map) == len(target_list)
+        assert target_class in explanation.saliency_map
+
+    def test_create_aise_detection_method(self):
+        """Test BlackBoxMethodFactory.create_method for the detection task."""
+        model = self.get_default_model()
+
+        preprocess_fn = get_preprocess_fn(
+            input_size=MODEL_CONFIGS[FAST_DET_MODEL].input_size,
+            hwc_to_chw=True,
+        )
+        det_xai_method = BlackBoxMethodFactory.create_method(
+            Task.DETECTION,
+            model,
+            preprocess_fn,
+        )
+        assert isinstance(det_xai_method, AISEDetection)
+
+    def get_default_model(self):
+        retrieve_otx_model(self.data_dir, FAST_DET_MODEL)
+        model_path = self.data_dir / "otx_models" / (FAST_DET_MODEL + ".xml")
+        model = ov.Core().read_model(model_path)
+        return model
+
+    @staticmethod
+    def postprocess_fn(x) -> np.ndarray:
+        """Returns boxes, scores, labels."""
+        return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
+
+
 class TestExample:
     """Test sanity of examples/run_detection.py."""

diff --git a/tests/unit/explanation/test_visualization.py b/tests/unit/explanation/test_visualization.py
index 470a27c7..0f3bf6d8 100644
--- a/tests/unit/explanation/test_visualization.py
+++ b/tests/unit/explanation/test_visualization.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest

+from openvino_xai.common.parameters import Task
 from openvino_xai.common.utils import get_min_max, scaling
 from openvino_xai.explainer.explanation import Explanation
 from openvino_xai.explainer.visualizer import Visualizer, colormap, overlay, resize
@@ -11,6 +12,10 @@
 SALIENCY_MAPS = [
     (np.random.rand(1, 5, 5) * 255).astype(np.uint8),
     (np.random.rand(1, 2, 5, 5) * 255).astype(np.uint8),
+    {
+        0: (np.random.rand(5, 5) * 255).astype(np.uint8),
+        1: (np.random.rand(5, 5) * 255).astype(np.uint8),
+    },
 ]

 EXPLAIN_ALL_CLASSES = [
@@ -98,7 +103,7 @@ class TestVisualizer:
     @pytest.mark.parametrize("colormap", [True, False])
     @pytest.mark.parametrize("overlay", [True, False])
     @pytest.mark.parametrize("overlay_weight", [0.5, 0.3])
-    def test_Visualizer(
+    def test_visualizer(
         self,
         saliency_maps,
         explain_all_classes,
@@ -142,7 +147,7 @@ def test_visualizer(
         for map_ in explanation.saliency_map.values():
             assert map_.shape[:2] == original_input_image.shape[:2]

-        if saliency_maps.ndim == 3 and not overlay:
+        if isinstance(saliency_maps, np.ndarray) and saliency_maps.ndim == 3 and not overlay:
             explanation = Explanation(saliency_maps, targets=-1)
             visualizer = Visualizer()
             explanation_output_size = visualizer(
@@ -157,3 +162,23 @@ def test_visualizer(
             maps_data = explanation.saliency_map
             maps_size = explanation_output_size.saliency_map
             assert np.all(maps_data["per_image_map"] == maps_size["per_image_map"])
+
+        if isinstance(saliency_maps, dict):
+            metadata = {
+                Task.DETECTION: {
+                    0: ([5, 0, 7, 4], 0.5, 0),
+                    1: ([2, 5, 9, 7], 0.5, 0),
+                }
+            }
+            explanation = Explanation(saliency_maps, targets=-1, metadata=metadata)
+            visualizer = Visualizer()
+            explanation_output_size = visualizer(
+                explanation=explanation,
+                original_input_image=original_input_image,
+                output_size=(20, 20),
+                scaling=scaling,
+                resize=resize,
+                colormap=colormap,
+                overlay=overlay,
+                overlay_weight=overlay_weight,
+            )
diff --git a/tests/unit/methods/black_box/test_black_box_method.py b/tests/unit/methods/black_box/test_black_box_method.py
index 32c5504e..c9b48e68 100644
--- a/tests/unit/methods/black_box/test_black_box_method.py
+++ b/tests/unit/methods/black_box/test_black_box_method.py
@@ -11,10 +11,12 @@

 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
-from openvino_xai.methods.black_box.aise import AISE
+from openvino_xai.methods.black_box.aise.classification import AISEClassification
+from openvino_xai.methods.black_box.aise.detection import AISEDetection
 from openvino_xai.methods.black_box.base import Preset
 from openvino_xai.methods.black_box.rise import RISE
 from tests.intg.test_classification import DEFAULT_CLS_MODEL
+from tests.intg.test_detection import DEFAULT_DET_MODEL


 class InputSampling:
@@ -26,11 +28,17 @@ class InputSampling:
     )
     postprocess_fn = get_postprocess_fn()

-    def get_model(self, fxt_data_root):
+    def get_cls_model(self, fxt_data_root):
         retrieve_otx_model(fxt_data_root, DEFAULT_CLS_MODEL)
         model_path = fxt_data_root / "otx_models" / (DEFAULT_CLS_MODEL + ".xml")
         return ov.Core().read_model(model_path)

+    def get_det_model(self, fxt_data_root):
+        detection_model = "det_yolox_bccd"
+        retrieve_otx_model(fxt_data_root, detection_model)
+        model_path = fxt_data_root / "otx_models" / (detection_model + ".xml")
+        return ov.Core().read_model(model_path)
+
     def _generate_with_preset(self, method, preset):
         _ = method.generate_saliency_map(
             data=self.image,
@@ -38,13 +46,25 @@ def _generate_with_preset(self, method, preset):
             preset=preset,
         )

+    @staticmethod
+    def preprocess_det_fn(x: np.ndarray) -> np.ndarray:
+        x = cv2.resize(src=x, dsize=(416, 416))  # OTX YOLOX
+        x = x.transpose((2, 0, 1))
+        x = np.expand_dims(x, 0)
+        return x
+
+    @staticmethod
+    def postprocess_det_fn(x) -> np.ndarray:
+        """Returns boxes, scores, labels."""
+        return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
+

-class TestAISE(InputSampling):
+class TestAISEClassification(InputSampling):
     @pytest.mark.parametrize("target_indices", [[0], [0, 1]])
     def test_run(self, target_indices, fxt_data_root: Path):
-        model = self.get_model(fxt_data_root)
+        model = self.get_cls_model(fxt_data_root)

-        aise_method = AISE(model, self.postprocess_fn, self.preprocess_fn)
+        aise_method = AISEClassification(model, self.postprocess_fn, self.preprocess_fn)
         saliency_map = aise_method.generate_saliency_map(
             data=self.image,
             target_indices=target_indices,
@@ -70,8 +90,73 @@ def test_run(self, target_indices, fxt_data_root: Path):
         assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

     def test_preset(self, fxt_data_root: Path):
-        model = self.get_model(fxt_data_root)
-        method = AISE(model, self.postprocess_fn, self.preprocess_fn)
+        model = self.get_cls_model(fxt_data_root)
+        method = AISEClassification(model, self.postprocess_fn, self.preprocess_fn)
+
+        tic = time.time()
+        self._generate_with_preset(method, Preset.SPEED)
+        toc = time.time()
+        time_speed = toc - tic
+
+        tic = time.time()
+        self._generate_with_preset(method, Preset.BALANCE)
+        toc = time.time()
+        time_balance = toc - tic
+
+        tic = time.time()
+        self._generate_with_preset(method, Preset.QUALITY)
+        toc = time.time()
+        time_quality = toc - tic
+
+        assert time_speed < time_balance < time_quality
+
+
+class TestAISEDetection(InputSampling):
+    @pytest.mark.parametrize("target_indices", [[0], [0, 1]])
+    def test_run(self, target_indices, fxt_data_root: Path):
+        model = self.get_det_model(fxt_data_root)
+
+        aise_method = AISEDetection(model, self.postprocess_det_fn, self.preprocess_det_fn)
+        saliency_map = aise_method.generate_saliency_map(
+            data=self.image,
+            target_indices=target_indices,
+            preset=Preset.SPEED,
+            num_iterations_per_kernel=10,
+            divisors=[5],
+        )
+        assert aise_method.num_iterations_per_kernel == 10
+        assert aise_method.divisors == [5]
+
+        assert isinstance(saliency_map, dict)
+        assert len(saliency_map) == len(target_indices)
+        for target in target_indices:
+            assert target in saliency_map
+
+        ref_target = 0
+        assert saliency_map[ref_target].dtype == np.uint8
+        assert saliency_map[ref_target].shape == (416, 416)
+        assert (saliency_map[ref_target] >= 0).all() and (saliency_map[ref_target] <= 255).all()
+
+        actual_sal_vals = saliency_map[0][150, 240:250].astype(np.int16)
+        ref_sal_vals = np.array([152, 168, 184, 199, 213, 225, 235, 243, 247, 249], dtype=np.uint8)
+        assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
+
+    def test_target_none(self, fxt_data_root: Path):
+        model = self.get_det_model(fxt_data_root)
+
+        aise_method = AISEDetection(model, self.postprocess_det_fn, self.preprocess_det_fn)
+        saliency_map = aise_method.generate_saliency_map(
+            data=self.image,
+            target_indices=None,
+            preset=Preset.SPEED,
+            num_iterations_per_kernel=1,
+            divisors=[5],
+        )
+        assert len(saliency_map) == 56
+
+    def test_preset(self, fxt_data_root: Path):
+        model = self.get_det_model(fxt_data_root)
+        method = AISEDetection(model, self.postprocess_det_fn, self.preprocess_det_fn)

         tic = time.time()
         self._generate_with_preset(method, Preset.SPEED)
@@ -94,7 +179,7 @@ class TestRISE(InputSampling):
     @pytest.mark.parametrize("target_indices", [[0], None])
     def test_run(self, target_indices, fxt_data_root: Path):
-        model = self.get_model(fxt_data_root)
+        model = self.get_cls_model(fxt_data_root)

         rise_method = RISE(model, self.postprocess_fn, self.preprocess_fn)
         saliency_map = rise_method.generate_saliency_map(
@@ -123,7 +208,7 @@ def test_run(self, target_indices, fxt_data_root: Path):
         assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

     def test_preset(self, fxt_data_root: Path):
-        model = self.get_model(fxt_data_root)
+        model = self.get_cls_model(fxt_data_root)
         method = RISE(model, self.postprocess_fn, self.preprocess_fn)

         tic = time.time()
diff --git a/tests/unit/methods/test_factory.py b/tests/unit/methods/test_factory.py
index 987b4bda..b197c1c2 100644
--- a/tests/unit/methods/test_factory.py
+++ b/tests/unit/methods/test_factory.py
@@ -10,7 +10,7 @@
 from openvino_xai.common.parameters import Method, Task
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
-from openvino_xai.methods.black_box.aise import AISE
+from openvino_xai.methods.black_box.aise.classification import AISEClassification
 from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory
 from openvino_xai.methods.white_box.activation_map import ActivationMap
 from openvino_xai.methods.white_box.det_class_probability_map import (
     DetClassProbabilityMap,
 )
@@ -108,7 +108,7 @@ def test_create_bb_cls_vit_method(fxt_data_root: Path):
     model_path = fxt_data_root / "otx_models" / (VIT_MODEL + ".xml")
     model_vit = ov.Core().read_model(model_path)
     explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model_vit, get_postprocess_fn())
-    assert isinstance(explain_method, AISE)
+    assert isinstance(explain_method, AISEClassification)


 def test_create_wb_det_cnn_method(fxt_data_root: Path):
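Putting the pieces together, end-to-end black-box detection explanation follows examples/run_detection.py. The condensed sketch below assumes a 416x416 YOLOX-style model with OTX-format outputs and uses placeholder paths.

```python
# End-to-end sketch of black-box detection XAI, condensed from examples/run_detection.py.
import cv2
import numpy as np
import openvino as ov

import openvino_xai as xai
from openvino_xai.explainer.explainer import ExplainMode
from openvino_xai.methods.black_box.base import Preset


def preprocess_fn(x: np.ndarray) -> np.ndarray:
    x = cv2.resize(src=x, dsize=(416, 416))  # assumed model input size
    x = x.transpose((2, 0, 1))
    return np.expand_dims(x, 0)


def postprocess_fn(x) -> np.ndarray:
    """Returns boxes, scores, labels."""
    return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]


model = ov.Core().read_model("detector.xml")  # placeholder path
explainer = xai.Explainer(
    model=model,
    task=xai.Task.DETECTION,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    explain_mode=ExplainMode.BLACKBOX,
)
explanation = explainer(
    cv2.imread("image.jpg"),  # placeholder path
    targets=[0],  # index of the predicted box to explain
    overlay=True,  # also draws the box, label and score from the new metadata
    preset=Preset.SPEED,  # speed-quality tradeoff
)
explanation.save("output/detection_black_box", "image_")
```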