Skip to content

Commit

Permalink
Add metrics to Anomaly OV Task (#3471)
Browse files Browse the repository at this point in the history
add metrics to openvino model

Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 authored May 9, 2024
1 parent 3f32845 commit c871af4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 31 deletions.
86 changes: 81 additions & 5 deletions src/otx/algo/anomaly/openvino_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,93 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np
import torch
from anomalib.metrics import create_metric_collection
from lightning import Callback, Trainer
from torchvision.transforms.functional import resize

from otx.core.data.entity.anomaly import AnomalyClassificationDataBatch
from otx.core.metrics.types import MetricCallable, NullMetricCallable
from otx.core.model.anomaly import AnomalyModelInputs
from otx.core.model.base import OVModel

if TYPE_CHECKING:
from anomalib.metrics import AnomalibMetricCollection
from model_api.models import Model
from model_api.models.anomaly import AnomalyResult


class _OVMetricCallback(Callback):
def __init__(self) -> None:
super().__init__()

def on_test_epoch_start(self, trainer: Trainer, pl_module: AnomalyOpenVINO) -> None:
pl_module.image_metrics.reset()
pl_module.pixel_metrics.reset()

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyOpenVINO,
outputs: list[AnomalyResult],
batch: AnomalyModelInputs,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
# Convert modelAPI scores to anomaly scores. i.e flip scores with Normal label.
score_dict = {
"pred_scores": torch.tensor(
[output.pred_score if output.pred_label == "Anomaly" else 1 - output.pred_score for output in outputs],
),
"labels": torch.tensor(batch.labels) if batch.batch_size == 1 else torch.vstack(batch.labels),
}
if not isinstance(batch, AnomalyClassificationDataBatch):
score_dict["anomaly_maps"] = torch.tensor(np.array([output.anomaly_map for output in outputs])) / 255.0
score_dict["masks"] = batch.masks if batch.batch_size == 1 else torch.vstack(batch.masks)
# resize masks and anomaly maps to 256,256 as this is the size used in Anomalib
score_dict["masks"] = resize(score_dict["masks"], (256, 256))
score_dict["anomaly_maps"] = resize(score_dict["anomaly_maps"], (256, 256))

self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, score_dict)

def on_test_epoch_end(self, trainer: Trainer, pl_module: AnomalyOpenVINO) -> None:
self._log_metrics(pl_module)

def _update_metrics(
self,
image_metric: AnomalibMetricCollection,
pixel_metric: AnomalibMetricCollection,
outputs: dict[str, torch.Tensor],
) -> None:
"""Update performance metrics."""
image_metric.update(outputs["pred_scores"], outputs["labels"].int())
if "masks" in outputs and "anomaly_maps" in outputs:
pixel_metric.update(outputs["anomaly_maps"], outputs["masks"].int())

@staticmethod
def _log_metrics(pl_module: AnomalyOpenVINO) -> None:
"""Log computed performance metrics."""
if pl_module.pixel_metrics._update_called: # noqa: SLF001
pl_module.log_dict(pl_module.pixel_metrics, prog_bar=True)
pl_module.log_dict(pl_module.image_metrics, prog_bar=False)
else:
pl_module.log_dict(pl_module.image_metrics, prog_bar=True)


class AnomalyOpenVINO(OVModel):
"""Anomaly OpenVINO model."""

# [TODO](ashwinvaidya17): Remove LightningModule once OTXModel is updated to use LightningModule.
# NOTE: Ideally OVModel should not be a LightningModule

def __init__(
self,
model_name: str,
async_inference: bool = True,
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = NullMetricCallable,
metric: MetricCallable = NullMetricCallable, # Metrics is computed using Anomalib's metric
**kwargs,
) -> None:
super().__init__(
Expand All @@ -47,6 +109,9 @@ def __init__(
model_api_configuration=model_api_configuration,
metric=metric,
)
metric_names = ["AUROC", "F1Score"]
self.image_metrics: AnomalibMetricCollection = create_metric_collection(metric_names, prefix="image_")
self.pixel_metrics: AnomalibMetricCollection = create_metric_collection(metric_names, prefix="pixel_")

def _create_model(self) -> Model:
from model_api.adapters import OpenvinoAdapter, create_core, get_user_config
Expand All @@ -68,6 +133,10 @@ def _create_model(self) -> Model:
configuration=self.model_api_configuration,
)

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Return the metric callback."""
return _OVMetricCallback()

def test_step(self, inputs: AnomalyModelInputs, batch_idx: int) -> list[AnomalyResult]:
"""Return outputs from the OpenVINO model."""
return self.forward(inputs) # type: ignore[return-value]
Expand All @@ -79,3 +148,10 @@ def predict_step(self, inputs: AnomalyModelInputs, batch_idx: int) -> list[Anoma
def _customize_outputs(self, outputs: list[AnomalyResult], inputs: AnomalyModelInputs) -> list[AnomalyResult]:
"""Return outputs from the OpenVINO model as is."""
return outputs

def _customize_inputs(self, inputs: AnomalyModelInputs) -> dict[str, np.ndarray]:
"""Return inputs as is."""
inputs = super()._customize_inputs(inputs)
# model needs inputs in range 0-1
inputs["inputs"] = [value / 255.0 for value in inputs["inputs"]]
return inputs
44 changes: 18 additions & 26 deletions src/otx/core/model/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform


AnomalyModelInputs: TypeAlias = (
Expand Down Expand Up @@ -78,7 +77,7 @@ def __init__(
input_size=(1, 3, *image_shape),
mean=mean_values,
std=scale_values,
swap_rgb=False, # default value. Ideally, modelAPI should pass RGB inputs after the pre-processing step
swap_rgb=True, # BGR -> RGB
)

@property
Expand Down Expand Up @@ -145,8 +144,6 @@ def __init__(self) -> None:
self.optimizer: list[OptimizerCallable] | OptimizerCallable = None
self.scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = None
self._input_size: tuple[int, int] = (256, 256)
self.mean_values: tuple[float, float, float] = (0.0, 0.0, 0.0)
self.scale_values: tuple[float, float, float] = (1.0, 1.0, 1.0)
self.trainer: Trainer
self.model: nn.Module
self.image_threshold: BaseThreshold
Expand All @@ -161,14 +158,12 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Callback on saving checkpoint."""
super().on_save_checkpoint(checkpoint) # type: ignore[misc]

attrs = ["_task_type", "_input_size", "mean_values", "scale_values", "image_threshold", "pixel_threshold"]

attrs = ["_task_type", "_input_size", "image_threshold", "pixel_threshold"]
checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs}

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Callback on loading checkpoint."""
super().on_load_checkpoint(checkpoint) # type: ignore[misc]

if anomaly_attrs := checkpoint.get("anomaly"):
for key, value in anomaly_attrs.items():
setattr(self, key, value)
Expand Down Expand Up @@ -206,15 +201,21 @@ def task(self, value: OTXTaskType) -> None:
msg = f"Unexpected task type: {value}"
raise ValueError(msg)

def _extract_mean_scale_from_transforms(self, transforms: list[Transform]) -> None:
"""Extract mean and scale values from transforms."""
for transform in transforms:
def _get_values_from_transforms(
self,
key_name: str,
) -> tuple:
"""Get the value requested value from default transforms."""
for transform in self.configure_transforms().transforms: # type: ignore[attr-defined]
name = transform.__class__.__name__
if "Resize" in name:
self.input_size = transform.size * 2 # transform.size has value [size], so *2 gives (size, size)
if "Resize" in name and key_name == "input_size":
image_size = transform.size
elif "Normalize" in name:
self.mean_values = transform.mean
self.scale_values = transform.std
if key_name == "mean":
mean_value = transform.mean
elif key_name == "scale":
std_value = transform.std
return image_size, mean_value, std_value

@property
def trainable_model(self) -> str | None:
Expand All @@ -228,15 +229,6 @@ def trainable_model(self) -> str | None:
"""
return None

def setup(self, stage: str | None = None) -> None:
"""Setup the model."""
super().setup(stage) # type: ignore[misc]
if stage == "fit" and hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"):
if hasattr(self.trainer.datamodule.config, "test_subset"):
self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.test_subset.transforms)
elif hasattr(self.trainer.datamodule.config, "val_subset"):
self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.val_subset.transforms)

def configure_callbacks(self) -> list[Callback]:
"""Get all necessary callbacks required for training and post-processing on Anomalib models."""
image_metrics = ["AUROC", "F1Score"]
Expand Down Expand Up @@ -414,14 +406,14 @@ def export(
"""
min_val = self.normalization_metrics.state_dict()["min"].cpu().numpy().tolist()
max_val = self.normalization_metrics.state_dict()["max"].cpu().numpy().tolist()
image_shape = (256, 256) if self.input_size is None else self.input_size
image_shape, mean_values, scale_values = self._get_values_from_transforms("input_size")
exporter = _AnomalyModelExporter(
image_shape=image_shape,
image_threshold=self.image_threshold.value.cpu().numpy().tolist(),
pixel_threshold=self.pixel_threshold.value.cpu().numpy().tolist(),
task=self.task,
mean_values=self.mean_values,
scale_values=self.scale_values,
mean_values=mean_values,
scale_values=scale_values,
normalization_scale=max_val - min_val,
)
return exporter.export(
Expand Down

0 comments on commit c871af4

Please sign in to comment.