Skip to content

Commit

Permalink
Beginning refactor of evaluate module: (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogenesAnalytics committed Jan 7, 2024
1 parent 628ff9c commit 8d3c886
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 40 deletions.
6 changes: 3 additions & 3 deletions notebooks/demo/mnist_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@
"outputs": [],
"source": [
"# get custom class\n",
"from autoencoder.data.evaluate import AutoencoderEvaluator\n",
"from autoencoder.data.anomaly import ReconstructionError\n",
"\n",
"# get instance\n",
"ae_eval = AutoencoderEvaluator(autoencoder, test_ds, axis=(1, 2))\n",
"mnist_recon_error = ReconstructionError(autoencoder, test_ds, axis=(1, 2))\n",
"\n",
"# view distribution\n",
"ae_eval.view_error_distribution(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
"mnist_recon_error.view_error_distribution(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
]
}
],
Expand Down
12 changes: 5 additions & 7 deletions notebooks/demo/tf_flowers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,18 @@
{
"cell_type": "code",
"execution_count": null,
"id": "60a9bac7-2902-4ea7-9c06-216584369c6d",
"metadata": {
"scrolled": true
},
"id": "e5cb8f53-0aa8-4c9d-9110-18eef56ea660",
"metadata": {},
"outputs": [],
"source": [
"# get custom class\n",
"from autoencoder.data.evaluate import AutoencoderEvaluator\n",
"from autoencoder.data.anomaly import ReconstructionError\n",
"\n",
"# get instance\n",
"ae_eval = AutoencoderEvaluator(autoencoder, x_val)\n",
"tfflower_recon_error = ReconstructionError(autoencoder, x_val)\n",
"\n",
"# view distribution\n",
"ae_eval.view_error_distribution(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=100)"
"tfflower_recon_error.view_error_histogram(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=100)"
]
}
],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tools for evaluating an autoencoder's performance on a dataset."""
from dataclasses import InitVar
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Generator
from typing import List
Expand All @@ -9,32 +9,44 @@
from typing import Union

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

from ..model.base import BaseAutoencoder
from .visualize import plot_anomalous_images
from .visualize import plot_error_distribution


@dataclass
class AutoencoderEvaluator:
"""Class for probing dataset using trained autoencoder."""
class ReconstructionError:
"""Class for generating the reconstruction error for a dataset."""

ae: Union[tf.keras.Model, BaseAutoencoder]
dataset: tf.data.Dataset
axis: Tuple[int, ...] = (1, 2, 3)
file_paths: InitVar[Optional[List[str]]] = None

def __post_init__(self) -> None:
def __post_init__(self, file_paths: Optional[List[str]]) -> None:
"""Calculate and store errors, and threshold."""
# check file paths
if file_paths is None:
# get file paths from dataset
file_paths = self.get_file_paths(self.dataset)

# get the reconstruction error
self.errors: List[float] = list(self.gen_reconstruction_error())
self.errors = pd.DataFrame(
data=self.gen_reconstruction_error(),
columns=["reconstruction_error"],
index=file_paths,
)

# store threshold
self.threshold = self.calc_95th_threshold(self.errors)
# store 95th threshold
self.threshold = self.calc_95th_threshold(
self.errors["reconstruction_error"].values.tolist()
)

@staticmethod
def has_file_paths(dataset: tf.data.Dataset) -> Optional[List[str]]:
def get_file_paths(dataset: tf.data.Dataset) -> Optional[List[str]]:
"""See if tf.data.Dataset has file_paths attribute."""
# see if tensorflow dataset has custom file_paths attr
return dataset.file_paths if hasattr(dataset, "file_paths") else None
Expand Down Expand Up @@ -90,22 +102,35 @@ def gen_reconstruction_error(self) -> Generator[Any, None, None]:
# update errors list
yield from mse.numpy()

def view_error_distribution(
self, title: str = "Reconstruction Error Distribution", bins: int = 10**3
) -> None:
"""Plot the reconstruction error distribution."""
plot_error_distribution(self.errors, self.threshold, bins, title)

def view_anomalous_images(
def view_error_histogram(
self,
output_path: Optional[Union[str, Path]] = None,
title: str = "Reconstruction Error Histogram",
bins: int = 10**3,
label: str = "threshold_source",
additional_data: Optional[List["ReconstructionError"]] = None,
additional_labels: Optional[List[str]] = None,
) -> None:
"""Plot anomalous image inputs with their reconstructed outputs."""
# Use the trained autoencoder to predict and calculate reconstruction error
for batch_idx, (inputs, predictions, mse) in enumerate(
self.gen_batch_predictions()
):
# now build, display, and optionally save images
plot_anomalous_images(
inputs, predictions, mse, self.threshold, batch_idx, output_path
)
"""Plot the reconstruction error distribution."""
# setup list of data and labels
error_data = [self.errors["reconstruction_error"].values.tolist()]

# check for more data
if additional_data is not None:
# add in other data supplied
error_data += [
ds.errors["reconstruction_error"].tolist() for ds in additional_data
]

# check for more labels
if additional_labels is not None:
# get error labels
error_labels = [label] + additional_labels

else:
# otherwise don't use any
error_labels = None

# now plot
plot_error_distribution(
error_data, self.threshold, bins, title, labels=error_labels
)
16 changes: 12 additions & 4 deletions src/autoencoder/data/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -58,14 +57,19 @@ def compare_image_predictions(


def plot_error_distribution(
errors: Union[Tuple[float, ...], List[float]],
errors: List[List[float]],
threshold: float,
bins: int,
title: str,
density: bool = False,
labels: Optional[List[str]] = None,
) -> None:
"""Plot a simple histogram for the reconstruction error."""
# create histogram plot
plt.hist(errors, bins=bins)
# calculate alpha
alpha = 0.5 if len(errors) > 1 else None

# build histogram
plt.hist(x=errors, alpha=alpha, bins=bins, density=density)

# add title
plt.title(title)
Expand All @@ -74,6 +78,10 @@ def plot_error_distribution(
plt.xlabel("Reconstruction Error")
plt.ylabel("# Samples")

# set legend
if labels is not None:
plt.legend(labels)

# plotting threshold
plt.axvline(x=threshold, color="r", linestyle="dashed", linewidth=2)

Expand Down

0 comments on commit 8d3c886

Please sign in to comment.