Skip to content

Commit

Permalink
Update hyperparameters of bb methods (#65)
Browse files Browse the repository at this point in the history
* Update hyperparameters of bb methods

* fix typing

* update content and header levels
  • Loading branch information
negvet authored Sep 11, 2024
1 parent a95e871 commit 130b815
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 29 deletions.
25 changes: 15 additions & 10 deletions docs/source/user-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ Content:
- [White-Box mode](#white-box-mode)
- [Black-Box mode](#black-box-mode)
- [XAI insertion (white-box usage)](#xai-insertion-white-box-usage)
- [XAI method overview](#xai-method-overview)
- [XAI methods](#xai-methods)
- [Overview](#overview)
- [White-box methods](#white-box-methods)
- [Black-box methods](#black-box-methods)
- [Plot saliency maps](#plot-saliency-maps)
- [Saving saliency maps](#saving-saliency-maps)
- [Example scripts](#example-scripts)
Expand Down Expand Up @@ -333,7 +336,9 @@ model_xai = xai.insert_xai(
# ***** Downstream task: user's code that infers model_xai and picks 'saliency_map' output *****
```

## XAI method overview
## XAI methods

### Overview

At the moment, the following XAI methods are supported:

Expand All @@ -358,7 +363,7 @@ Target layer is the part of the model graph where XAI branch will be inserted (a

All supported methods are gradient-free, which suits deployment framework settings (e.g. OpenVINO™), where the model is in optimized or compiled representation.

## White-Box methods
### White-Box methods

When to use?
- When model architecture follows standard CNN-based or ViT-based design (OV-XAI [support](../../README.md#supported-explainable-models) 1000+ CNN and ViT models).
Expand All @@ -367,7 +372,7 @@ When to use?

All white-box methods require access to model internal state. To generate saliency map, supported white-box methods potentially change and process internal model activations in a way that fosters compute efficiency.

### Activation Map
#### Activation Map

Suitable for:
- Binary classification problems (e.g. inspecting model reasoning when predicting a positive class).
Expand All @@ -379,7 +384,7 @@ Below saliency map was obtained for [ResNet-18](https://huggingface.co/timm/resn

![OpenVINO XAI Architecture](_static/map_samples/ActivationMap_resnet18.a1_in1k_activation_map.jpg)

### Recipro-CAM (ViT Recipro-CAM for ViT models)
#### Recipro-CAM (ViT Recipro-CAM for ViT models)

Suitable for:
- Almost all CNN-based architectures.
Expand All @@ -398,7 +403,7 @@ Below saliency map was obtained for [ResNet-18](https://huggingface.co/timm/resn

![OpenVINO XAI Architecture](_static/map_samples/ReciproCAM_resnet18.a1_in1k_293.jpg)

### DetClassProbabilityMap
#### DetClassProbabilityMap

Suitable for:
- Single-stage object detection models.
Expand All @@ -412,7 +417,7 @@ Below saliency map was obtained for `YOLOX` trained in-house on PASCAL VOC datas

![OpenVINO XAI Architecture](_static/map_samples/DetClassProbabilityMap.jpg)

## Black-Box methods
### Black-Box methods

When to use?
- When custom models are used and/or white-box methods fail (e.g. Swin-based transformers).
Expand All @@ -425,7 +430,7 @@ Usually, for high quality saliency map, hundreds or thousands of model inference
Given that the quality of the saliency maps usually correlates with the number of available inferences, we propose the following presets for the black-box methods: `Preset.SPEED`, `Preset.BALANCE`, `Preset.QUALITY` (`Preset.BALANCE` is used by default).
Apart from that, methods parameters can be defined directly via Explainer or Method API.

### RISE
#### RISE

Suitable for:
- All classification models which can generate per-class prediction scores.
Expand All @@ -447,7 +452,7 @@ Below saliency map was obtained for [ResNet-18](https://huggingface.co/timm/resn

It is possible to see, that some grass-related pixels from the left cheetah also contribute to the cheetah prediction, which might indicates that model learned cheetah features in combination with grass (which makes sense).

### AISEClassification
#### AISEClassification

Suitable for:
- All classification models which can generate per-class prediction scores.
Expand All @@ -462,7 +467,7 @@ Below saliency map was obtained for [ResNet-18](https://huggingface.co/timm/resn

![OpenVINO XAI Architecture](_static/map_samples/AISE_resnet18.a1_in1k_293.jpg)

### AISEDetection
#### AISEDetection

Suitable for:
- All detection models which can generate bounding boxes, labels and scores.
Expand Down
8 changes: 5 additions & 3 deletions examples/run_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def explain_white_box(args):

# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "detection"
explanation.save(output, Path(args.image_path).stem)
output = Path(args.output) / "detection_white_box"
ori_image_name = Path(args.image_path).stem
explanation.save(output, f"{ori_image_name}_")


def explain_black_box(args):
Expand Down Expand Up @@ -131,7 +132,8 @@ def explain_black_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "detection_black_box"
explanation.save(output, f"{Path(args.image_path).stem}_")
ori_image_name = Path(args.image_path).stem
explanation.save(output, f"{ori_image_name}_")


def main(argv):
Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ def _preset_parameters(
kernel_widths: List[float] | np.ndarray | None,
) -> Tuple[int, np.ndarray]:
if preset == Preset.SPEED:
iterations = 25
iterations = 20
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.BALANCE:
iterations = 50
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.QUALITY:
iterations = 85
widths = np.linspace(0.075, 0.25, 4)
iterations = 50
widths = np.linspace(0.075, 0.25, 5)
else:
raise ValueError(f"Preset {preset} is not supported.")

Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/aise/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ def _preset_parameters(
divisors: List[float] | np.ndarray | None,
) -> Tuple[int, np.ndarray]:
if preset == Preset.SPEED:
iterations = 50
iterations = 20
divs = np.linspace(7, 1, 3)
elif preset == Preset.BALANCE:
iterations = 100
iterations = 50
divs = np.linspace(7, 1, 3)
elif preset == Preset.QUALITY:
iterations = 150
iterations = 50
divs = np.linspace(8, 1, 5)
else:
raise ValueError(f"Preset {preset} is not supported.")
Expand Down
27 changes: 17 additions & 10 deletions openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_saliency_map(
target_indices: List[int] | None = None,
preset: Preset = Preset.BALANCE,
num_masks: int | None = None,
num_cells: int = 8,
num_cells: int | None = None,
prob: float = 0.5,
seed: int = 0,
scale_output: bool = True,
Expand Down Expand Up @@ -84,7 +84,7 @@ def generate_saliency_map(
"""
data_preprocessed = self.preprocess_fn(data)

num_masks = self._preset_parameters(preset, num_masks)
num_masks, num_cells = self._preset_parameters(preset, num_masks, num_cells)

saliency_maps = self._run_synchronous_explanation(
data_preprocessed,
Expand All @@ -109,20 +109,27 @@ def generate_saliency_map(
def _preset_parameters(
preset: Preset,
num_masks: int | None = None,
) -> int:
# TODO (negvet): preset num_cells
if num_masks is not None:
return num_masks

num_cells: int | None = None,
) -> Tuple[int, int]:
if preset == Preset.SPEED:
return 2000
num_masks_ = 1000
num_cells_ = 4
elif preset == Preset.BALANCE:
return 5000
num_masks_ = 5000
num_cells_ = 8
elif preset == Preset.QUALITY:
return 8000
num_masks_ = 10000
num_cells_ = 12
else:
raise ValueError(f"Preset {preset} is not supported.")

if num_masks is None:
num_masks = num_masks_
if num_cells is None:
num_cells = num_cells_

return num_masks, num_cells

def _run_synchronous_explanation(
self,
data_preprocessed: np.ndarray,
Expand Down

0 comments on commit 130b815

Please sign in to comment.