Skip to content

Commit

Permalink
Adding reconstruction error section
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogenesAnalytics committed Jan 2, 2024
1 parent c0b983a commit 9e7ae74
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
78 changes: 63 additions & 15 deletions notebooks/demo/mnist_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"outputs": [],
"source": [
"# get necessary libs for data/preprocessing\n",
"import tensorflow as tf\n",
"from keras.datasets import mnist\n",
"\n",
"# load the data\n",
Expand All @@ -56,8 +57,18 @@
"# preprocess the data (normalize)\n",
"x_train = x_train.astype(\"float32\") / 255.\n",
"x_test = x_test.astype(\"float32\") / 255.\n",
"print(f\"Train shape: {x_train.shape}\")\n",
"print(f\"Test shape: {x_test.shape}\")"
"\n",
"# convert to tf datasets\n",
"train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))\n",
"test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))\n",
"\n",
"# set a few params\n",
"BATCH_SIZE = 64\n",
"SHUFFLE_BUFFER_SIZE = 100\n",
"\n",
"# update with batch/buffer size\n",
"train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)\n",
"test_ds = test_ds.batch(BATCH_SIZE)"
]
},
{
Expand All @@ -77,12 +88,20 @@
"outputs": [],
"source": [
"# get libs for training ae\n",
"from autoencoder.model.minimal import Min2DAE\n",
"from autoencoder.model.minimal import Min2DAE, Min2DParams\n",
"\n",
"# seupt config\n",
"config = Min2DParams(\n",
" l0={\"input_shape\": (28, 28)},\n",
" l2={\"units\": 32 * 1},\n",
" l3={\"units\": 28 * 28 * 1},\n",
" l4={\"target_shape\": (28, 28)},\n",
")\n",
"\n",
"# get ae instance\n",
"autoencoder = Min2DAE()\n",
"autoencoder = Min2DAE(config)\n",
"\n",
"# check model topology\n",
"# check network topology\n",
"autoencoder.summary()"
]
},
Expand All @@ -103,13 +122,10 @@
"autoencoder.compile(optimizer=\"adam\", loss=\"binary_crossentropy\")\n",
"\n",
"# begin model fit\n",
"history = autoencoder.fit(\n",
" x=x_train,\n",
" y=x_train,\n",
"autoencoder.fit(\n",
" x=train_ds,\n",
" epochs=50,\n",
" batch_size=256,\n",
" shuffle=True,\n",
" validation_data=(x_test, x_test),\n",
" validation_data=test_ds,\n",
" callbacks=[early_stop_callback],\n",
")"
]
Expand Down Expand Up @@ -137,18 +153,50 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1ea208f6-b50a-405d-b07a-ad86736c0a43",
"id": "d14a4180-ca54-437d-abc4-21427f45eed2",
"metadata": {},
"outputs": [],
"source": [
"# get viz func\n",
"from autoencoder.data import compare_image_predictions\n",
"\n",
"# get decoded images\n",
"decoded_imgs = autoencoder.predict(x=x_train)\n",
"# get samples from validation dataset\n",
"val_samples = test_ds.take(1)\n",
"\n",
"# get raw numpy arrays\n",
"val_input = [item for pair in val_samples.as_numpy_iterator() for item in pair[0]]\n",
"\n",
"# and decoded\n",
"decoded_imgs = autoencoder.predict(x=val_samples)\n",
"\n",
"# display\n",
"compare_image_predictions(x_train, decoded_imgs)"
"compare_image_predictions(val_input, decoded_imgs)"
]
},
{
"cell_type": "markdown",
"id": "e299de94-d802-4d3d-82c5-875ddde80232",
"metadata": {},
"source": [
"## Reconstruction Error Distribution\n",
"Now let us take peak into this dataset and see how well the *autoencoder* is capturing the image features."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60a9bac7-2902-4ea7-9c06-216584369c6d",
"metadata": {},
"outputs": [],
"source": [
"# get custom class\n",
"from autoencoder.data.evaluate import AutoencoderEvaluator\n",
"\n",
"# get instance\n",
"ae_eval = AutoencoderEvaluator(autoencoder, test_ds, axis=(1, 2))\n",
"\n",
"# view distribution\n",
"ae_eval.view_error_distribution(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
]
}
],
Expand Down
28 changes: 28 additions & 0 deletions notebooks/demo/tf_flowers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,34 @@
"# display\n",
"compare_image_predictions(val_input, decoded_imgs)"
]
},
{
"cell_type": "markdown",
"id": "672e0631-950c-44f8-9f7a-9268062c1a55",
"metadata": {},
"source": [
"## Reconstruction Error Distribution\n",
"Now let us take peak into this dataset and see how well the *autoencoder* is capturing the image features."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60a9bac7-2902-4ea7-9c06-216584369c6d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# get custom class\n",
"from autoencoder.data.evaluate import AutoencoderEvaluator\n",
"\n",
"# get instance\n",
"ae_eval = AutoencoderEvaluator(autoencoder, x_val)\n",
"\n",
"# view distribution\n",
"ae_eval.view_error_distribution(\"tf_flowers Autoencoder: Reconstruction Error Distribution\")"
]
}
],
"metadata": {
Expand Down

0 comments on commit 9e7ae74

Please sign in to comment.