Skip to content

Commit

Permalink
Updated recon error histogram for multi alphas/bins
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogenesAnalytics committed Jan 7, 2024
1 parent 8d3c886 commit 0be245c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 15 deletions.
2 changes: 1 addition & 1 deletion notebooks/demo/mnist_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@
"mnist_recon_error = ReconstructionError(autoencoder, test_ds, axis=(1, 2))\n",
"\n",
"# view distribution\n",
"mnist_recon_error.view_error_distribution(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
"mnist_recon_error.histogram(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion notebooks/demo/tf_flowers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
"tfflower_recon_error = ReconstructionError(autoencoder, x_val)\n",
"\n",
"# view distribution\n",
"tfflower_recon_error.view_error_histogram(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=100)"
"tfflower_recon_error.histogram(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=[100])"
]
}
],
Expand Down
76 changes: 68 additions & 8 deletions src/autoencoder/data/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,16 @@ def gen_reconstruction_error(self) -> Generator[Any, None, None]:
# update errors list
yield from mse.numpy()

def view_error_histogram(
def _plot_error_distribution(
self,
title: str = "Reconstruction Error Histogram",
bins: int = 10**3,
label: str = "threshold_source",
additional_data: Optional[List["ReconstructionError"]] = None,
additional_labels: Optional[List[str]] = None,
title: str,
label: str,
density: bool,
additional_data: Optional[List["ReconstructionError"]],
additional_labels: Optional[List[str]],
alphas: Optional[List[float]],
bins: Optional[List[int]],
) -> None:
"""Plot the reconstruction error distribution."""
# setup list of data and labels
error_data = [self.errors["reconstruction_error"].values.tolist()]

Expand All @@ -130,7 +131,66 @@ def view_error_histogram(
# otherwise don't use any
error_labels = None

# check for alphas
if alphas is None:
# determine alpha value
alph_val = 0.5 if len(error_data) > 1 else 1

# build alphas list
alphas = [alph_val] * len(error_data)

# check for bins
if bins is None:
# build default bins list
bins = [10**3] * len(error_data)

# now plot
plot_error_distribution(
error_data, self.threshold, bins, title, labels=error_labels
errors=error_data,
threshold=self.threshold,
title=title,
bins=bins,
alphas=alphas,
labels=error_labels,
density=density,
)

def histogram(
self,
title: str = "Reconstruction Error Histogram",
label: str = "threshold_source",
additional_data: Optional[List["ReconstructionError"]] = None,
additional_labels: Optional[List[str]] = None,
alphas: Optional[List[float]] = None,
bins: Optional[List[int]] = None,
) -> None:
"""Plot the reconstruction error as a histogram."""
self._plot_error_distribution(
title=title,
bins=bins,
label=label,
additional_data=additional_data,
additional_labels=additional_labels,
density=False,
alphas=alphas,
)

def probability_distribution(
self,
title: str = "Reconstruction Error Probability Distribution",
label: str = "threshold_source",
additional_data: Optional[List["ReconstructionError"]] = None,
additional_labels: Optional[List[str]] = None,
alphas: Optional[List[float]] = None,
bins: Optional[List[int]] = None,
) -> None:
"""Plot the reconstruction error as a probability distribution."""
self._plot_error_distribution(
title=title,
bins=bins,
label=label,
additional_data=additional_data,
additional_labels=additional_labels,
density=True,
alphas=alphas,
)
13 changes: 8 additions & 5 deletions src/autoencoder/data/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,20 @@ def compare_image_predictions(
def plot_error_distribution(
errors: List[List[float]],
threshold: float,
bins: int,
title: str,
bins: List[int],
alphas: List[float],
density: bool = False,
labels: Optional[List[str]] = None,
) -> None:
"""Plot a simple histogram for the reconstruction error."""
# calculate alpha
alpha = 0.5 if len(errors) > 1 else None

# build histogram
plt.hist(x=errors, alpha=alpha, bins=bins, density=density)
for (
err_data,
alph,
bn,
) in zip(errors, alphas, bins, strict=True):
plt.hist(x=err_data, alpha=alph, bins=bn, density=density, stacked=False)

# add title
plt.title(title)
Expand Down

0 comments on commit 0be245c

Please sign in to comment.