-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e8434ff
commit a49334d
Showing
1 changed file
with
303 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d313573c-0644-44ca-96ef-571ebbb3c250", | ||
"metadata": {}, | ||
"source": [ | ||
"# Anomaly Detection: MNIST vs. TF Flowers\n", | ||
"The following `Jupyter Notebook` explores the use of *anomaly detection*: first training a simple *autoencoder* (the fully connected `MinNDAE` model), and exploring the *reconstruction error*." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "799a59d8-7e7f-450c-9fd0-e21749b7cd75", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup\n", | ||
"Need to get the necessary packages ..." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9f3b8122-297d-4662-acc7-d7a7b930243d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# check for colab\n", | ||
"if \"google.colab\" in str(get_ipython()):\n", | ||
" # install colab dependencies\n", | ||
" !pip install git+https://github.com/DiogenesAnalytics/autoencoder" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "dde3750d-4828-430b-9cc8-231066c37d35", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get MNIST Data\n", | ||
"Wille use `keras.datasets` to get the `MNIST` dataset, and then do some *normalizing* and *reshaping* to prepare it for the *autoencoder*." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bc7c0488-abe6-453a-aae4-e3f9a392736a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get necessary libs for data/preprocessing\n", | ||
"import tensorflow as tf\n", | ||
"from keras.datasets import mnist\n", | ||
"\n", | ||
"# load the data\n", | ||
"(x_train, _), (x_test, _) = mnist.load_data()\n", | ||
"\n", | ||
"# preprocess the data (normalize)\n", | ||
"x_train = x_train.astype(\"float32\") / 255.\n", | ||
"x_test = x_test.astype(\"float32\") / 255.\n", | ||
"\n", | ||
"# add grayscale dimension\n", | ||
"x_train = tf.expand_dims(x_train, axis=-1)\n", | ||
"x_test = tf.expand_dims(x_test, axis=-1)\n", | ||
"\n", | ||
"# convert to tf datasets\n", | ||
"train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))\n", | ||
"test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))\n", | ||
"\n", | ||
"# set a few params\n", | ||
"BATCH_SIZE = 64\n", | ||
"SHUFFLE_BUFFER_SIZE = 100\n", | ||
"\n", | ||
"# update with batch/buffer size\n", | ||
"train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)\n", | ||
"test_ds = test_ds.batch(BATCH_SIZE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0d588a1c-a082-405a-9b55-4399ff580879", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get tf_flowers Data\n", | ||
"The [TensorFlow Flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset first needs to be downloaded, and then preprocessed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "03bec295-bd65-4f3b-9149-e82b311246b8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# libs for tf flowers data\n", | ||
"import keras\n", | ||
"import pathlib\n", | ||
"\n", | ||
"# data location\n", | ||
"DATASET_URL = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n", | ||
"\n", | ||
"# download, get path, and convert to pathlib obj\n", | ||
"TF_FLOWERS_DATA_DIR = pathlib.Path(\n", | ||
" keras.utils.get_file(\"flower_photos\", origin=DATASET_URL, untar=True, cache_dir=\"./data/keras\")\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "20ae5d2e-0209-4edb-ae66-ab5746ad278a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get keras image dataset util func\n", | ||
"from keras.utils import image_dataset_from_directory\n", | ||
"\n", | ||
"# create normalization func\n", | ||
"def normalize(x):\n", | ||
" return x / 255.\n", | ||
"\n", | ||
"# use keras util to load raw images into tensorflow.data.Dataset\n", | ||
"anomalous_data = image_dataset_from_directory(\n", | ||
" TF_FLOWERS_DATA_DIR,\n", | ||
" labels=None,\n", | ||
" color_mode=\"grayscale\",\n", | ||
" validation_split=None,\n", | ||
" shuffle=True,\n", | ||
" subset=None,\n", | ||
" seed=42,\n", | ||
" image_size=(28, 28),\n", | ||
" batch_size=3670,\n", | ||
").map(normalize)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ee2e329b-1c9d-4e09-af91-b38e93e6613f", | ||
"metadata": {}, | ||
"source": [ | ||
"## Autoencoder Training\n", | ||
"Finally the *autoencoder* can be trained ..." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2d440a72-b357-4794-a5b9-f4b4ee790524", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get libs for training ae\n", | ||
"from autoencoder.model.minimal import MinNDAE, MinNDParams\n", | ||
"\n", | ||
"# seupt config\n", | ||
"config = MinNDParams(\n", | ||
" l0={\"input_shape\": (28, 28, 1)},\n", | ||
" l2={\"units\": 32 * 1},\n", | ||
" l3={\"units\": 28 * 28 * 1},\n", | ||
" l4={\"target_shape\": (28, 28, 1)},\n", | ||
")\n", | ||
"\n", | ||
"# get ae instance\n", | ||
"autoencoder = MinNDAE(config)\n", | ||
"\n", | ||
"# check network topology\n", | ||
"autoencoder.summary()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "da883534-e913-494b-bafe-564d8b25f76c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get code for callbacks and custom loss function\n", | ||
"from autoencoder.training import build_anomaly_loss_function\n", | ||
"from keras.callbacks import EarlyStopping\n", | ||
"\n", | ||
"# create callback\n", | ||
"early_stop_callback = EarlyStopping(monitor=\"val_anomaly_diff\", patience=2)\n", | ||
"\n", | ||
"# get custom loss func\n", | ||
"custom_loss = build_anomaly_loss_function(next(iter(anomalous_data)), autoencoder)\n", | ||
"\n", | ||
"# compile ae\n", | ||
"autoencoder.compile(\n", | ||
" optimizer=\"adam\",\n", | ||
" loss=custom_loss,\n", | ||
" metrics=[custom_loss],\n", | ||
")\n", | ||
"\n", | ||
"# begin model fit\n", | ||
"autoencoder.fit(\n", | ||
" x=train_ds,\n", | ||
" epochs=10**2,\n", | ||
" validation_data=test_ds,\n", | ||
" callbacks=[early_stop_callback],\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2ece10e9-bcd3-44f9-a15d-bf2ed9b06e4f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# view training loss\n", | ||
"autoencoder.training_history()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a4629952-4d4e-411e-9841-a38e9787ca43", | ||
"metadata": {}, | ||
"source": [ | ||
"## Reconstruction Error Distribution\n", | ||
"Now let us take peak into this dataset and see how well the *autoencoder* is working as an *anomaly detector* (i.e. how **low** vs. how **high** the *reconstruction* error is for the training and anomalous datasets respectively)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a782719c-b51b-414d-a8e5-483a9efde43b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get custom anomaly detection class\n", | ||
"from autoencoder.data.anomaly import AnomalyDetector\n", | ||
"\n", | ||
"# get mnist instance\n", | ||
"mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2, 3))\n", | ||
"\n", | ||
"# calculate recon error\n", | ||
"mnist_recon_error.calculate_error()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aa9a70d6-9e58-42ea-9602-27660316292f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get tf flowers instance\n", | ||
"tfflower_recon_error = AnomalyDetector(autoencoder, anomalous_data)\n", | ||
"\n", | ||
"# calculate recon error\n", | ||
"tfflower_recon_error.calculate_error()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d7d09932-87c7-4643-90ab-d20fb1174ff8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# turn on interactive plot\n", | ||
"%matplotlib widget" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3a6c9f15-2ca1-405a-930a-a944ccd21e13", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# now compare recon error distributions\n", | ||
"mnist_recon_error.histogram(\n", | ||
" \"MNIST Anomaly Detection Using TF Flowers: MinNDAE\",\n", | ||
" label=\"mnist\",\n", | ||
" bins=[100, 100],\n", | ||
" additional_data=[tfflower_recon_error], \n", | ||
" additional_labels=[\"tf_flowers\"],\n", | ||
")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |