Skip to content

Commit

Permalink
Unify postprocess functions (#60)
Browse files Browse the repository at this point in the history
* unify postprocess fn + output check

* black + isort

* Fix tests
  • Loading branch information
negvet authored Aug 30, 2024
1 parent 656345c commit 79c0677
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 15 deletions.
4 changes: 3 additions & 1 deletion examples/run_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ def preprocess_fn(x: np.ndarray) -> np.ndarray:

def postprocess_fn(x) -> np.ndarray:
"""Returns boxes, scores, labels."""
return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
return x["boxes"][:, :, :4], x["boxes"][:, :, 4], x["labels"]


def explain_white_box(args):
"""
White-box scenario.
Per-class saliency map generation for single-stage detection models (using DetClassProbabilityMap).
Insertion of the XAI branch into the model, thus model has additional 'saliency_map' output.
"""

Expand Down Expand Up @@ -95,6 +96,7 @@ def explain_white_box(args):
def explain_black_box(args):
"""
Black-box scenario.
Per-box saliency map generation for all detection models (using AISEDetection).
"""

# Create ov.Model
Expand Down
10 changes: 6 additions & 4 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
)
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_classification_output


class AISEClassification(AISEBase):
"""
AISE for classification models.
postprocess_fn expected to return one container with scores. Without batch dim.
postprocess_fn expected to return one container with scores. With batch dimention equals to one.
:param model: OpenVINO model.
:type model: ov.Model
Expand Down Expand Up @@ -144,11 +145,12 @@ def _preset_parameters(
kernel_widths = widths
return num_iterations_per_kernel, kernel_widths

def _get_loss(self, data_perturbed: np.array) -> float:
def _get_loss(self, data_perturbed: np.ndarray) -> float:
"""Get loss for perturbed input."""
x = self.model_forward(data_perturbed, preprocess=False)
x = self.postprocess_fn(x)
check_classification_output(x)

if np.max(x) > 1 or np.min(x) < 0:
x = sigmoid(x)
pred_scores = x.squeeze() # type: ignore
return pred_scores[self.target]
return x[0][self.target]
17 changes: 11 additions & 6 deletions openvino_xai/methods/black_box/aise/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
)
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_detection_output


class AISEDetection(AISEBase):
"""
AISE for detection models.
postprocess_fn expected to return three containers: boxes (format: [x1, y1, x2, y2]), scores, labels. Without batch dim.
postprocess_fn expected to return three containers: boxes (format: [x1, y1, x2, y2]), scores, labels. With batch dimention equals to one.
:param model: OpenVINO model.
:type model: ov.Model
Expand Down Expand Up @@ -93,8 +94,11 @@ def generate_saliency_map( # type: ignore
self.data_preprocessed = self.preprocess_fn(data)
forward_output = self.model_forward(self.data_preprocessed, preprocess=False)

# postprocess_fn expected to return three containers: boxes (x1, y1, x2, y2), scores, labels, without batch dim.
boxes, scores, labels = self.postprocess_fn(forward_output)
# postprocess_fn expected to return three containers: boxes (x1, y1, x2, y2), scores, labels.
output = self.postprocess_fn(forward_output)
check_detection_output(output)
boxes, scores, labels = output
boxes, scores, labels = boxes[0], scores[0], labels[0]

if target_indices is None:
num_boxes = len(boxes)
Expand Down Expand Up @@ -181,12 +185,13 @@ def _process_box(self, padding_coef: float = 0.5) -> None:
def _get_loss(self, data_perturbed: np.array) -> float:
"""Get loss for perturbed input."""
forward_output = self.model_forward(data_perturbed, preprocess=False)
boxes, pred_scores, labels = self.postprocess_fn(forward_output)
boxes, scores, labels = self.postprocess_fn(forward_output)
boxes, scores, labels = boxes[0], scores[0], labels[0]

loss = 0
for box, pred_score, label in zip(boxes, pred_scores, labels):
for box, score, label in zip(boxes, scores, labels):
if label == self.target_label:
loss = max(loss, self._iou(self.target_box, box) * pred_score)
loss = max(loss, self._iou(self.target_box, box) * score)
return loss

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@
import openvino.runtime as ov

from openvino_xai.methods.base import MethodBase
from openvino_xai.methods.black_box.utils import check_classification_output


class BlackBoxXAIMethod(MethodBase):
"""Base class for methods that explain model in Black-Box mode."""

def prepare_model(self, load_model: bool = True) -> ov.Model:
"""Load model prior to inference."""
if load_model:
self.load_model()
return self._model

def get_num_classes(self, data_preprocessed):
"""Estimates number of classes for the classification model. Expects batch dimention."""
forward_output = self.model_forward(data_preprocessed, preprocess=False)
logits = self.postprocess_fn(forward_output)
_, num_classes = logits.shape
return num_classes
check_classification_output(logits)
return logits.shape[1]


class Preset(Enum):
Expand Down
4 changes: 4 additions & 0 deletions openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset
from openvino_xai.methods.black_box.utils import check_classification_output


class RISE(BlackBoxXAIMethod):
"""RISE explains classification models in black-box mode using
'RISE: Randomized Input Sampling for Explanation of Black-box Models' paper
(https://arxiv.org/abs/1806.07421).
postprocess_fn expected to return one container with scores. With batch dimention equals to one.
:param model: OpenVINO model.
:type model: ov.Model
:param postprocess_fn: Post-processing function that extract scores from IR model output.
Expand Down Expand Up @@ -149,6 +152,7 @@ def _run_synchronous_explanation(

forward_output = self.model_forward(masked, preprocess=False)
raw_scores = self.postprocess_fn(forward_output)
check_classification_output(raw_scores)

sal = self._get_scored_mask(raw_scores, mask, target_classes)
saliency_maps += sal
Expand Down
39 changes: 39 additions & 0 deletions openvino_xai/methods/black_box/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Tuple

import numpy as np


def check_classification_output(x: np.ndarray) -> None:
"""Checks output of the postprocess function provided by the user (for classification talk)."""
if not isinstance(x, np.ndarray):
raise RuntimeError("Postprocess function should return numpy array.")
if x.ndim != 2 or x.shape[0] != 1:
raise RuntimeError("Postprocess function should return two dimentional numpy array with batch size of 1.")


def check_detection_output(x: Tuple[np.ndarray, np.ndarray, np.ndarray]) -> None:
"""Checks output of the postprocess function provided by the user (for detection task)."""
if not hasattr(x, "__len__"):
raise RuntimeError("Postprocess function should return sized object.")

if len(x) != 3:
raise RuntimeError(
"Postprocess function should return three containers: boxes (format: [x1, y1, x2, y2]), scores, labels."
)

for item in x:
if not isinstance(item, np.ndarray):
raise RuntimeError("Postprocess function should return numpy arrays.")
if item.shape[0] != 1:
raise RuntimeError("Postprocess function should return numpy arrays with batch size of 1.")

boxes, scores, labels = x
if boxes.ndim != 3:
raise RuntimeError("Boxes should be three-dimentional [Batch, NumBoxes, BoxCoords].")
if scores.ndim != 2:
raise RuntimeError("Scores should be two-dimentional [Batch, Scores].")
if labels.ndim != 2:
raise RuntimeError("Labels should be two-dimentional [Batch, Labels].")
2 changes: 1 addition & 1 deletion tests/intg/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def get_default_model(self):
@staticmethod
def postprocess_fn(x) -> np.ndarray:
"""Returns boxes, scores, labels."""
return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
return x["boxes"][:, :, :4], x["boxes"][:, :, 4], x["labels"]


class TestExample:
Expand Down
58 changes: 57 additions & 1 deletion tests/unit/methods/black_box/test_black_box_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from openvino_xai.methods.black_box.aise.detection import AISEDetection
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.rise import RISE
from openvino_xai.methods.black_box.utils import (
check_classification_output,
check_detection_output,
)
from tests.intg.test_classification import DEFAULT_CLS_MODEL
from tests.intg.test_detection import DEFAULT_DET_MODEL

Expand Down Expand Up @@ -56,7 +60,7 @@ def preprocess_det_fn(x: np.ndarray) -> np.ndarray:
@staticmethod
def postprocess_det_fn(x) -> np.ndarray:
"""Returns boxes, scores, labels."""
return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]
return x["boxes"][:, :, :4], x["boxes"][:, :, 4], x["labels"]


class TestAISEClassification(InputSampling):
Expand Down Expand Up @@ -227,3 +231,55 @@ def test_preset(self, fxt_data_root: Path):
time_quality = toc - tic

assert time_speed < time_balance < time_quality


def test_check_classification_output():
with pytest.raises(Exception) as exc_info:
x = 1
check_classification_output(x)
assert str(exc_info.value) == "Postprocess function should return numpy array."

with pytest.raises(Exception) as exc_info:
x = np.zeros((2, 2, 2))
check_classification_output(x)
assert str(exc_info.value) == "Postprocess function should return two dimentional numpy array with batch size of 1."


def test_check_detection_output():
with pytest.raises(Exception) as exc_info:
x = 1
check_detection_output(x)
assert str(exc_info.value) == "Postprocess function should return sized object."

with pytest.raises(Exception) as exc_info:
x = 1, 2
check_detection_output(x)
assert (
str(exc_info.value)
== "Postprocess function should return three containers: boxes (format: [x1, y1, x2, y2]), scores, labels."
)

with pytest.raises(Exception) as exc_info:
x = np.array([1]), np.array([1]), 1
check_detection_output(x)
assert str(exc_info.value) == "Postprocess function should return numpy arrays."

with pytest.raises(Exception) as exc_info:
x = np.ones((1, 2)), np.ones((1, 2)), np.ones((2, 2))
check_detection_output(x)
assert str(exc_info.value) == "Postprocess function should return numpy arrays with batch size of 1."

with pytest.raises(Exception) as exc_info:
x = np.ones((1, 2)), np.ones((1)), np.ones((1, 2, 3))
check_detection_output(x)
assert str(exc_info.value) == "Boxes should be three-dimentional [Batch, NumBoxes, BoxCoords]."

with pytest.raises(Exception) as exc_info:
x = np.ones((1, 2, 4)), np.ones((1)), np.ones((1, 2, 3))
check_detection_output(x)
assert str(exc_info.value) == "Scores should be two-dimentional [Batch, Scores]."

with pytest.raises(Exception) as exc_info:
x = np.ones((1, 2, 4)), np.ones((1, 2)), np.ones((1, 2, 3))
check_detection_output(x)
assert str(exc_info.value) == "Labels should be two-dimentional [Batch, Labels]."

0 comments on commit 79c0677

Please sign in to comment.