diff --git a/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb b/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb index 0bf52a43..619a6132 100644 --- a/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb +++ b/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb @@ -1,400 +1,501 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Convolutional Neural Network Example\n", - "\n", - "Build a convolutional neural network with TensorFlow v2.\n", - "\n", - "This example is using a low-level approach to better understand all mechanics behind building convolutional neural networks and the training process.\n", - "\n", - "- Author: Aymeric Damien\n", - "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## CNN Overview\n", - "\n", - "![CNN](http://personal.ie.cuhk.edu.hk/~ccloy/project_target_code/images/fig3.png)\n", - "\n", - "## MNIST Dataset Overview\n", - "\n", - "This example is using MNIST handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 255. \n", - "\n", - "In this example, each image will be converted to float32 and normalized to [0, 1].\n", - "\n", - "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", - "\n", - "More info: http://yann.lecun.com/exdb/mnist/" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "import tensorflow as tf\n", - "from tensorflow.keras import Model, layers\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# MNIST dataset parameters.\n", - "num_classes = 10 # total classes (0-9 digits).\n", - "\n", - "# Training parameters.\n", - "learning_rate = 0.001\n", - "training_steps = 200\n", - "batch_size = 128\n", - "display_step = 10\n", - "\n", - "# Network parameters.\n", - "conv1_filters = 32 # number of filters for 1st conv layer.\n", - "conv2_filters = 64 # number of filters for 2nd conv layer.\n", - "fc1_units = 1024 # number of neurons for 1st fully-connected layer." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare MNIST data.\n", - "from tensorflow.keras.datasets import mnist\n", - "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", - "# Convert to float32.\n", - "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n", - "# Normalize images value from [0, 255] to [0, 1].\n", - "x_train, x_test = x_train / 255., x_test / 255." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# Use tf.data API to shuffle and batch data.\n", - "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", - "train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# Create TF Model.\n", - "class ConvNet(Model):\n", - " # Set layers.\n", - " def __init__(self):\n", - " super(ConvNet, self).__init__()\n", - " # Convolution Layer with 32 filters and a kernel size of 5.\n", - " self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)\n", - " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", - " self.maxpool1 = layers.MaxPool2D(2, strides=2)\n", - "\n", - " # Convolution Layer with 64 filters and a kernel size of 3.\n", - " self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)\n", - " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", - " self.maxpool2 = layers.MaxPool2D(2, strides=2)\n", - "\n", - " # Flatten the data to a 1-D vector for the fully connected layer.\n", - " self.flatten = layers.Flatten()\n", - "\n", - " # Fully connected layer.\n", - " self.fc1 = layers.Dense(1024)\n", - " # Apply Dropout (if is_training is False, dropout is not applied).\n", - " self.dropout = layers.Dropout(rate=0.5)\n", - "\n", - " # Output layer, class prediction.\n", - " self.out = layers.Dense(num_classes)\n", - "\n", - " # Set forward pass.\n", - " def call(self, x, is_training=False):\n", - " x = tf.reshape(x, [-1, 28, 28, 1])\n", - " x = self.conv1(x)\n", - " x = self.maxpool1(x)\n", - " x = self.conv2(x)\n", - " x = self.maxpool2(x)\n", - " x = self.flatten(x)\n", - " x = self.fc1(x)\n", - " x = self.dropout(x, training=is_training)\n", - " x = self.out(x)\n", - " if not is_training:\n", - " # tf cross entropy expect logits without softmax, so only\n", - " # apply softmax when not training.\n", - " x = tf.nn.softmax(x)\n", - " return x\n", - "\n", - "# Build neural network model.\n", - "conv_net = ConvNet()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# Cross-Entropy Loss.\n", - "# Note that this will apply 'softmax' to the logits.\n", - "def cross_entropy_loss(x, y):\n", - " # Convert labels to int 64 for tf cross-entropy function.\n", - " y = tf.cast(y, tf.int64)\n", - " # Apply softmax to logits and compute cross-entropy.\n", - " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n", - " # Average loss across the batch.\n", - " return tf.reduce_mean(loss)\n", - "\n", - "# Accuracy metric.\n", - "def accuracy(y_pred, y_true):\n", - " # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n", - " correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n", - " return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n", - "\n", - "# Stochastic gradient descent optimizer.\n", - "optimizer = tf.optimizers.Adam(learning_rate)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Optimization process. \n", - "def run_optimization(x, y):\n", - " # Wrap computation inside a GradientTape for automatic differentiation.\n", - " with tf.GradientTape() as g:\n", - " # Forward pass.\n", - " pred = conv_net(x, is_training=True)\n", - " # Compute loss.\n", - " loss = cross_entropy_loss(pred, y)\n", - " \n", - " # Variables to update, i.e. trainable variables.\n", - " trainable_variables = conv_net.trainable_variables\n", - "\n", - " # Compute gradients.\n", - " gradients = g.gradient(loss, trainable_variables)\n", - " \n", - " # Update W and b following gradients.\n", - " optimizer.apply_gradients(zip(gradients, trainable_variables))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 10, loss: 1.877234, accuracy: 0.789062\n", - "step: 20, loss: 1.609390, accuracy: 0.898438\n", - "step: 30, loss: 1.555458, accuracy: 0.960938\n", - "step: 40, loss: 1.588296, accuracy: 0.921875\n", - "step: 50, loss: 1.561057, accuracy: 0.929688\n", - "step: 60, loss: 1.539851, accuracy: 0.945312\n", - "step: 70, loss: 1.527458, accuracy: 0.976562\n", - "step: 80, loss: 1.526701, accuracy: 0.945312\n", - "step: 90, loss: 1.522610, accuracy: 0.968750\n", - "step: 100, loss: 1.514970, accuracy: 0.968750\n", - "step: 110, loss: 1.489902, accuracy: 0.976562\n", - "step: 120, loss: 1.520697, accuracy: 0.953125\n", - "step: 130, loss: 1.494326, accuracy: 0.968750\n", - "step: 140, loss: 1.501781, accuracy: 0.984375\n", - "step: 150, loss: 1.506588, accuracy: 0.976562\n", - "step: 160, loss: 1.493378, accuracy: 0.984375\n", - "step: 170, loss: 1.494006, accuracy: 0.984375\n", - "step: 180, loss: 1.509782, accuracy: 0.953125\n", - "step: 190, loss: 1.516123, accuracy: 0.953125\n", - "step: 200, loss: 1.515508, accuracy: 0.953125\n" - ] - } - ], - "source": [ - "# Run training for the given number of steps.\n", - "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n", - " # Run the optimization to update W and b values.\n", - " run_optimization(batch_x, batch_y)\n", - " \n", - " if step % display_step == 0:\n", - " pred = conv_net(batch_x)\n", - " loss = cross_entropy_loss(pred, batch_y)\n", - " acc = accuracy(pred, batch_y)\n", - " print(\"step: %i, loss: %f, accuracy: %f\" % (step, loss, acc))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "I5_S5gwhZ1nD" + }, + "source": [ + "# Convolutional Neural Network Example\n", + "\n", + "Build a convolutional neural network with TensorFlow v2.\n", + "\n", + "This example is using a low-level approach to better understand all mechanics behind building convolutional neural networks and the training process.\n", + "\n", + "- Author: Aymeric Damien\n", + "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Test Accuracy: 0.977700\n" - ] - } - ], - "source": [ - "# Test model on validation set.\n", - "pred = conv_net(x_test)\n", - "print(\"Test Accuracy: %f\" % accuracy(pred, y_test))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# Visualize predictions.\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "YTK3ftGYZ1nZ" + }, + "source": [ + "## CNN Overview\n", + "\n", + "![CNN](http://personal.ie.cuhk.edu.hk/~ccloy/project_target_code/images/fig3.png)\n", + "\n", + "## MNIST Dataset Overview\n", + "\n", + "This example is using MNIST handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 255. \n", + "\n", + "In this example, each image will be converted to float32 and normalized to [0, 1].\n", + "\n", + "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", + "\n", + "More info: http://yann.lecun.com/exdb/mnist/" + ] + }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADQNJREFUeJzt3W+MVfWdx/HPZylNjPQBWLHEgnQb3bgaAzoaE3AzamxYbYKN1NQHGzbZMH2AZps0ZA1PypMmjemfrU9IpikpJtSWhFbRGBeDGylRGwejBYpQICzMgkAzJgUT0yDfPphDO8W5v3u5/84dv+9XQube8z1/vrnhM+ecOefcnyNCAPL5h7obAFAPwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKnP9HNjtrmdEOixiHAr83W057e9wvZB24dtP9nJugD0l9u9t9/2LEmHJD0gaVzSW5Iei4jfF5Zhzw/0WD/2/HdJOhwRRyPiz5J+IWllB+sD0EedhP96SSemvB+vpv0d2yO2x2yPdbAtAF3WyR/8pju0+MRhfUSMShqVOOwHBkkne/5xSQunvP+ipJOdtQOgXzoJ/1uSbrT9JduflfQNSdu70xaAXmv7sD8iLth+XNL/SJolaVNE7O9aZwB6qu1LfW1tjHN+oOf6cpMPgJmL8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaTaHqJbkmwfk3RO0seSLkTEUDeaAtB7HYW/cm9E/LEL6wHQRxz2A0l1Gv6QtMP2Htsj3WgIQH90eti/LCJO2p4v6RXb70XErqkzVL8U+MUADBhHRHdWZG+QdD4ivl+YpzsbA9BQRLiV+do+7Ld9te3PXXot6SuS9rW7PgD91clh/3WSfm370np+HhEvd6UrAD3XtcP+ljbGYT/Qcz0/7AcwsxF+ICnCDyRF+IGkCD+QFOEHkurGU30prFq1qmFtzZo1xWVPnjxZrH/00UfF+pYtW4r1999/v2Ht8OHDxWWRF3t+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iKR3pbdPTo0Ya1xYsX96+RaZw7d65hbf/+/X3sZLCMj483rD311FPFZcfGxrrdTt/wSC+AIsIPJEX4gaQIP5AU4QeSIvxAUoQfSIrn+VtUemb/tttuKy574MCBYv3mm28u1m+//fZifXh4uGHt7rvvLi574sSJYn3hwoXFeicuXLhQrJ89e7ZYX7BgQdvbPn78eLE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR5ftubJH1V0pmIuLWaNk/SLyUtlnRM0qMR8UHTjc3g5/kH2dy5cxvWlixZUlx2z549xfqdd97ZVk+taDZewaFDh4r1ZvdPzJs3r2Ft7dq1xWU3btxYrA+ybj7P/zNJKy6b9qSknRFxo6Sd1XsAM0jT8EfELkkTl01eKWlz9XqzpIe73BeAHmv3nP+6iDglSdXP+d1rCUA/9PzeftsjkkZ6vR0AV6bdPf9p2wskqfp5ptGMETEaEUMRMdTmtgD0QLvh3y5pdfV6taTnu9MOgH5pGn7bz0p6Q9I/2R63/R+SvifpAdt/kPRA9R7ADML39mNgPfLII8X61q1bi/V9+/Y1rN17773FZScmLr/ANXPwvf0Aigg/kBThB5Ii/EBShB9IivADSXGpD7WZP7/8SMjevXs7Wn7VqlUNa9u2bSsuO5NxqQ9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMUQ3ahNs6/Pvvbaa4v1Dz4of1v8wYMHr7inTNjzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSPM+Pnlq2bFnD2quvvlpcdvbs2cX68PBwsb5r165i/dOK5/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFJNn+e3vUnSVyWdiYhbq2kbJK2RdLaabX1EvNSrJjFzPfjggw1rza7j79y5s1h/44032uoJk1rZ8/9M0opppv8oIpZU/wg+MMM0DX9E7JI00YdeAPRRJ+f8j9v+ne1Ntud2rSMAfdFu+DdK+rKkJZJOSfpBoxltj9gesz3W5rYA9EBb4Y+I0xHxcURclPQTSXcV5h2NiKGIGGq3SQDd11b4bS+Y8vZrkvZ1px0A/dLKpb5nJQ1L+rztcUnfkTRse4mkkHRM0jd72COAHuB5fnTkqquuKtZ3797dsHbLLbcUl73vvvuK9ddff71Yz4rn+QEUEX4gKcIPJEX4gaQIP5AU4QeSYohudGTdunXF+tKlSxvWXn755eKyXMrrLfb8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AUj/Si6KGHHirWn3vuuWL9ww8/bFhbsWK6L4X+mzfffLNYx/R4pBdAEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMXz/Mldc801xfrTTz9drM+aNatYf+mlxgM4cx2/Xuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpps/z214o6RlJX5B0UdJoRPzY9jxJv5S0WNIxSY9GxAdN1sXz/H3W7Dp8s2vtd9xxR7F+5MiRYr30zH6zZdGebj7Pf0HStyPiZkl3S1pr+58lPSlpZ0TcKGln9R7ADNE0/BFxKiLerl6fk3RA0vWSVkraXM22WdLDvWoSQPdd0Tm/7cWSlkr6raTrIuKUNPkLQtL8bjcHoHdavrff9hxJ2yR9KyL+ZLd0WiHbI5JG2msPQK+0tOe3PVuTwd8SEb+qJp+2vaCqL5B0ZrplI2I0IoYiYqgbDQPojqbh9+Qu/qeSDkTED6eUtktaXb1eLen57rcHoFdaudS3XNJvJO3V5KU+SVqvyfP+rZIWSTou6esRMdFkXVzq67ObbrqpWH/vvfc6Wv/KlSuL9RdeeKGj9ePKtXqpr+k5f0TsltRoZfdfSVMABgd3+AFJEX4gKcIPJEX4gaQIP5AU4QeS4qu7PwVuuOGGhrUdO3Z0tO5169YV6y+++GJH60d92PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5/8UGBlp/C1pixYt6mjdr732WrHe7PsgMLjY8wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUlznnwGWL19erD/xxBN96gSfJuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpptf5bS+U9IykL0i6KGk0In5se4OkNZLOVrOuj4iXetVoZvfcc0+xPmfOnLbXfeTIkWL9/Pnzba8bg62Vm3wuSPp2RLxt+3OS9th+par9KCK+37v2APRK0/BHxClJp6rX52wfkHR9rxsD0FtXdM5ve7GkpZJ+W0163PbvbG+yPbfBMiO2x2yPddQpgK5qOfy250jaJulbEfEnSRslfVnSEk0eGfxguuUiYjQihiJiqAv9AuiSlsJve7Ymg78lIn4lSRFxOiI+joiLkn4i6a7etQmg25qG37Yl/VTSgYj44ZTpC6bM9jVJ+7rfHoBeaeWv/csk/Zukvbbfqaatl/SY7SWSQtIxSd/sSYfoyLvvvlus33///cX6xMREN9vBAGnlr/27JXmaEtf0gRmMO/yApAg/kBThB5Ii/EBShB9IivADSbmfQyzbZjxnoMciYrpL85/Anh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkur3EN1/lPR/U95/vpo2iAa1t0HtS6K3dnWztxtanbGvN/l8YuP22KB+t9+g9jaofUn01q66euOwH0iK8ANJ1R3+0Zq3XzKovQ1qXxK9tauW3mo95wdQn7r3/ABqUkv4ba+wfdD2YdtP1tFDI7aP2d5r+526hxirhkE7Y3vflGnzbL9i+w/Vz2mHSauptw22/7/67N6x/WBNvS20/b+2D9jeb/s/q+m1fnaFvmr53Pp+2G97lqRDkh6QNC7pLUmPRcTv+9pIA7aPSRqKiNqvCdv+F0nnJT0TEbdW056SNBER36t+cc6NiP8akN42SDpf98jN1YAyC6aOLC3pYUn/rho/u0Jfj6qGz62OPf9dkg5HxNGI+LOkX0haWUMfAy8idkm6fNSMlZI2V683a/I/T9816G0gRMSpiHi7en1O0qWRpWv97Ap91aKO8F8v6cSU9+MarCG/Q9IO23tsj9TdzDSuq4ZNvzR8+vya+7lc05Gb++mykaUH5rNrZ8Trbqsj/NN9xdAgXXJYFhG3S/pXSWurw1u0pqWRm/tlmpGlB0K7I153Wx3hH5e0cMr7L0o6WUMf04qIk9XPM5J+rcEbffj0pUFSq59nau7nrwZp5ObpRpbWAHx2gzTidR3hf0vSjba/ZPuzkr4haXsNfXyC7aurP8TI9tWSvqLBG314u6TV1evVkp6vsZe/MygjNzcaWVo1f3aDNuJ1LTf5VJcy/lvSLEmbIuK7fW9iGrb/UZN7e2nyicef19mb7WclDWvyqa/Tkr4j6TlJWyUtknRc0tcjou9/eGvQ27AmD13/OnLzpXPsPve2XNJvJO2VdLGavF6T59e1fXaFvh5TDZ8bd/gBSXGHH5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpP4CIJjqosJxHysAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "zL08AkXuZ1nb" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.keras import Model, layers\n", + "import numpy as np" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model prediction: 7\n" - ] + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "72ipgy3nZ1nh" + }, + "outputs": [], + "source": [ + "# MNIST dataset parameters.\n", + "num_classes = 10 # total classes (0-9 digits).\n", + "\n", + "# Training parameters.\n", + "learning_rate = 0.001\n", + "training_steps = 200\n", + "batch_size = 128\n", + "display_step = 10\n", + "\n", + "# Network parameters.\n", + "conv1_filters = 32 # number of filters for 1st conv layer.\n", + "conv2_filters = 64 # number of filters for 2nd conv layer.\n", + "fc1_units = 1024 # number of neurons for 1st fully-connected layer." + ] }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADYNJREFUeJzt3X+oXPWZx/HPZ20CYouaFLMXYzc16rIqauUqiy2LSzW6S0wMWE3wjyy77O0fFbYYfxGECEuwLNvu7l+BFC9NtLVpuDHGWjYtsmoWTPAqGk2TtkauaTbX3A0pNkGkJnn2j3uy3MY7ZyYzZ+bMzfN+QZiZ88w552HI555z5pw5X0eEAOTzJ3U3AKAehB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKf6+XKbHM5IdBlEeFW3tfRlt/2nbZ/Zfs92491siwAveV2r+23fZ6kX0u6XdJBSa9LWhERvyyZhy0/0GW92PLfLOm9iHg/Iv4g6ceSlnawPAA91En4L5X02ymvDxbT/ojtIdujtkc7WBeAinXyhd90uxaf2a2PiPWS1kvs9gP9pJMt/0FJl015PV/Soc7aAdArnYT/dUlX2v6y7dmSlkvaVk1bALqt7d3+iDhh+wFJ2yWdJ2k4IvZU1hmArmr7VF9bK+OYH+i6nlzkA2DmIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKme3rob7XnooYdK6+eff37D2nXXXVc67z333NNWT6etW7eutP7aa681rD399NMdrRudYcsPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lx994+sGnTptJ6p+fi67R///6Gtdtuu6103gMHDlTdTgrcvRdAKcIPJEX4gaQIP5AU4QeSIvxAUoQfSKqj3/PbHpN0TNJJSSciYrCKps41dZ7H37dvX2l9+/btpfXLL7+8tH7XXXeV1hcuXNiwdv/995fO++STT5bW0Zkqbubx1xFxpILlAOghdvuBpDoNf0j6ue03bA9V0RCA3uh0t/+rEXHI9iWSfmF7X0S8OvUNxR8F/jAAfaajLX9EHCoeJyQ9J+nmad6zPiIG+TIQ6C9th9/2Bba/cPq5pEWS3q2qMQDd1clu/zxJz9k+vZwfRcR/VtIVgK5rO/wR8b6k6yvsZcYaHCw/olm2bFlHy9+zZ09pfcmSJQ1rR46Un4U9fvx4aX327Nml9Z07d5bWr7++8X+RuXPnls6L7uJUH5AU4QeSIvxAUoQfSIrwA0kRfiAphuiuwMDAQGm9uBaioWan8u64447S+vj4eGm9E6tWrSqtX3311W0v+8UXX2x7XnSOLT+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMV5/gq88MILpfUrrriitH7s2LHS+tGjR8+6p6osX768tD5r1qwedYKqseUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQ4z98DH3zwQd0tNPTwww+X1q+66qqOlr9r1662aug+tvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kJQjovwN9rCkxZImIuLaYtocSZskLZA0JuneiPhd05XZ5StD5RYvXlxa37x5c2m92RDdExMTpfWy+wG88sorpfOiPRFRPlBEoZUt/w8k3XnGtMckvRQRV0p6qXgNYAZpGv6IeFXSmbeSWSppQ/F8g6S7K+4LQJe1e8w/LyLGJal4vKS6lgD0Qtev7bc9JGmo2+sBcHba3fIftj0gScVjw299ImJ9RAxGxGCb6wLQBe2Gf5uklcXzlZKer6YdAL3SNPy2n5X0mqQ/t33Q9j9I+o6k223/RtLtxWsAM0jTY/6IWNGg9PWKe0EXDA6WH201O4/fzKZNm0rrnMvvX1zhByRF+IGkCD+QFOEHkiL8QFKEH0iKW3efA7Zu3dqwtmjRoo6WvXHjxtL6448/3tHyUR+2/EBShB9IivADSRF+ICnCDyRF+IGkCD+QVNNbd1e6Mm7d3ZaBgYHS+ttvv92wNnfu3NJ5jxw5Ulq/5ZZbSuv79+8vraP3qrx1N4BzEOEHkiL8QFKEH0iK8ANJEX4gKcIPJMXv+WeAkZGR0nqzc/llnnnmmdI65/HPXWz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCppuf5bQ9LWixpIiKuLaY9IekfJf1v8bbVEfGzbjV5rluyZElp/cYbb2x72S+//HJpfc2aNW0vGzNbK1v+H0i6c5rp/xYRNxT/CD4wwzQNf0S8KuloD3oB0EOdHPM/YHu37WHbF1fWEYCeaDf86yQtlHSDpHFJ3230RttDtkdtj7a5LgBd0Fb4I+JwRJyMiFOSvi/p5pL3ro+IwYgYbLdJANVrK/y2p95Odpmkd6tpB0CvtHKq71lJt0r6ou2DktZIutX2DZJC0pikb3axRwBd0DT8EbFimslPdaGXc1az39uvXr26tD5r1qy21/3WW2+V1o8fP972sjGzcYUfkBThB5Ii/EBShB9IivADSRF+IClu3d0Dq1atKq3fdNNNHS1/69atDWv8ZBeNsOUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQcEb1bmd27lfWRTz75pLTeyU92JWn+/PkNa+Pj4x0tGzNPRLiV97HlB5Ii/EBShB9IivADSRF+ICnCDyRF+IGk+D3/OWDOnDkNa59++mkPO/msjz76qGGtWW/Nrn+48MIL2+pJki666KLS+oMPPtj2sltx8uTJhrVHH320dN6PP/64kh7Y8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUk3P89u+TNJGSX8q6ZSk9RHxH7bnSNokaYGkMUn3RsTvutcqGtm9e3fdLTS0efPmhrVm9xqYN29eaf2+++5rq6d+9+GHH5bW165dW8l6Wtnyn5C0KiL+QtJfSvqW7aslPSbppYi4UtJLxWsAM0TT8EfEeES8WTw/JmmvpEslLZW0oXjbBkl3d6tJANU7q2N+2wskfUXSLknzImJcmvwDIemSqpsD0D0tX9tv+/OSRiR9OyJ+b7d0mzDZHpI01F57ALqlpS2/7VmaDP4PI2JLMfmw7YGiPiBpYrp5I2J9RAxGxGAVDQOoRtPwe3IT/5SkvRHxvSmlbZJWFs9XSnq++vYAdEvTW3fb/pqkHZLe0eSpPklarcnj/p9I+pKkA5K+ERFHmywr5a27t2zZUlpfunRpjzrJ5cSJEw1rp06dalhrxbZt20rro6OjbS97x44dpfWdO3eW1lu9dXfTY/6I+G9JjRb29VZWAqD/cIUfkBThB5Ii/EBShB9IivADSRF+ICmG6O4DjzzySGm90yG8y1xzzTWl9W7+bHZ4eLi0PjY21tHyR0ZGGtb27dvX0bL7GUN0AyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ4fOMdwnh9AKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9Iqmn4bV9m+79s77W9x/Y/FdOfsP0/tt8q/v1t99sFUJWmN/OwPSBpICLetP0FSW9IulvSvZKOR8S/trwybuYBdF2rN/P4XAsLGpc0Xjw/ZnuvpEs7aw9A3c7qmN/2AklfkbSrmPSA7d22h21f3GCeIdujtkc76hRApVq+h5/tz0t6RdLaiNhie56kI5JC0j9r8tDg75ssg91+oMta3e1vKfy2Z0n6qaTtEfG9aeoLJP00Iq5tshzCD3RZZTfwtG1JT0naOzX4xReBpy2T9O7ZNgmgPq182/81STskvSPpVDF5taQVkm7Q5G7/mKRvFl8Oli2LLT/QZZXu9leF8APdx337AZQi/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJNX0Bp4VOyLpgymvv1hM60f92lu/9iXRW7uq7O3PWn1jT3/P/5mV26MRMVhbAyX6tbd+7Uuit3bV1Ru7/UBShB9Iqu7wr695/WX6tbd+7Uuit3bV0lutx/wA6lP3lh9ATWoJv+07bf/K9nu2H6ujh0Zsj9l+pxh5uNYhxoph0CZsvztl2hzbv7D9m+Jx2mHSauqtL0ZuLhlZutbPrt9GvO75br/t8yT9WtLtkg5Kel3Sioj4ZU8bacD2mKTBiKj9nLDtv5J0XNLG06Mh2f4XSUcj4jvFH86LI+LRPuntCZ3lyM1d6q3RyNJ/pxo/uypHvK5CHVv+myW9FxHvR8QfJP1Y0tIa+uh7EfGqpKNnTF4qaUPxfIMm//P0XIPe+kJEjEfEm8XzY5JOjyxd62dX0lct6gj/pZJ+O+X1QfXXkN8h6ee237A9VHcz05h3emSk4vGSmvs5U9ORm3vpjJGl++aza2fE66rVEf7pRhPpp1MOX42IGyX9jaRvFbu3aM06SQs1OYzbuKTv1tlMMbL0iKRvR8Tv6+xlqmn6quVzqyP8ByVdNuX1fEmHauhjWhFxqHickPScJg9T+snh04OkFo8TNffz/yLicEScjIhTkr6vGj+7YmTpEUk/jIgtxeTaP7vp+qrrc6sj/K9LutL2l23PlrRc0rYa+vgM2xcUX8TI9gWSFqn/Rh/eJmll8XylpOdr7OWP9MvIzY1GllbNn12/jXhdy0U+xamMf5d0nqThiFjb8yamYftyTW7tpclfPP6ozt5sPyvpVk3+6uuwpDWStkr6iaQvSTog6RsR0fMv3hr0dqvOcuTmLvXWaGTpXarxs6tyxOtK+uEKPyAnrvADkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5DU/wG6SwYLYCwMKQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Is91vr4NZ1nj", + "outputId": "ce69c031-e819-4db5-fe6c-214094e6ae2e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", + "11493376/11490434 [==============================] - 0s 0us/step\n", + "11501568/11490434 [==============================] - 0s 0us/step\n" + ] + } + ], + "source": [ + "# Prepare MNIST data.\n", + "from tensorflow.keras.datasets import mnist\n", + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "# Convert to float32.\n", + "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n", + "# Normalize images value from [0, 255] to [0, 1].\n", + "x_train, x_test = x_train / 255., x_test / 255." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model prediction: 2\n" - ] + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "bd6I640IZ1nk" + }, + "outputs": [], + "source": [ + "# Use tf.data API to shuffle and batch data.\n", + "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", + "train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)" + ] }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADCFJREFUeJzt3WGoXPWZx/Hvs1n7wrQvDDUarGu6RVdLxGS5iBBZXarFFSHmRaUKS2RL0xcNWNgXK76psBREtt1dfFFIaWgqrbVEs2pdbYsspguLGjVU21grcre9a8hVFGoVKSbPvrgn5VbvnLmZOTNnkuf7gTAz55kz52HI7/7PzDlz/pGZSKrnz/puQFI/DL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paL+fJobiwhPJ5QmLDNjNc8ba+SPiOsi4lcR8UpE3D7Oa0marhj13P6IWAO8DFwLLADPADdn5i9b1nHklyZsGiP/5cArmflqZv4B+AGwbYzXkzRF44T/POC3yx4vNMv+RETsjIiDEXFwjG1J6tg4X/ittGvxod36zNwN7AZ3+6VZMs7IvwCcv+zxJ4DXxmtH0rSME/5ngAsj4pMR8RHg88DD3bQladJG3u3PzPcjYhfwY2ANsCczf9FZZ5ImauRDfSNtzM/80sRN5SQfSacuwy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKmuoU3arnoosuGlh76aWXWte97bbbWuv33HPPSD1piSO/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxU11nH+iJgH3gaOAe9n5lwXTen0sWXLloG148ePt667sLDQdTtapouTfP42M9/o4HUkTZG7/VJR44Y/gZ9ExLMRsbOLhiRNx7i7/Vsz87WIWA/8NCJeyswDy5/Q/FHwD4M0Y8Ya+TPzteZ2EdgPXL7Cc3Zn5pxfBkqzZeTwR8TaiPjYifvAZ4EXu2pM0mSNs9t/DrA/Ik68zvcz8/FOupI0cSOHPzNfBS7rsBedhjZv3jyw9s4777Suu3///q7b0TIe6pOKMvxSUYZfKsrwS0UZfqkowy8V5aW7NZZNmza11nft2jWwdu+993bdjk6CI79UlOGXijL8UlGGXyrK8EtFGX6pKMMvFeVxfo3l4osvbq2vXbt2YO3+++/vuh2dBEd+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyoqMnN6G4uY3sY0FU8//XRr/eyzzx5YG3YtgGGX9tbKMjNW8zxHfqkowy8VZfilogy/VJThl4oy/FJRhl8qaujv+SNiD3ADsJiZm5pl64D7gY3APHBTZr41uTbVl40bN7bW5+bmWusvv/zywJrH8fu1mpH/O8B1H1h2O/BEZl4IPNE8lnQKGRr+zDwAvPmBxduAvc39vcCNHfclacJG/cx/TmYeAWhu13fXkqRpmPg1/CJiJ7Bz0tuRdHJGHfmPRsQGgOZ2cdATM3N3Zs5lZvs3Q5KmatTwPwzsaO7vAB7qph1J0zI0/BFxH/A/wF9FxEJEfAG4C7g2In4NXNs8lnQKGfqZPzNvHlD6TMe9aAZdddVVY63/+uuvd9SJuuYZflJRhl8qyvBLRRl+qSjDLxVl+KWinKJbrS699NKx1r/77rs76kRdc+SXijL8UlGGXyrK8EtFGX6pKMMvFWX4paKcoru4K664orX+6KOPttbn5+db61u3bh1Ye++991rX1WicoltSK8MvFWX4paIMv1SU4ZeKMvxSUYZfKsrf8xd3zTXXtNbXrVvXWn/88cdb6x7Ln12O/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1NDj/BGxB7gBWMzMTc2yO4EvAifmX74jM/9zUk1qci677LLW+rDrPezbt6/LdjRFqxn5vwNct8Lyf83Mzc0/gy+dYoaGPzMPAG9OoRdJUzTOZ/5dEfHziNgTEWd11pGkqRg1/N8EPgVsBo4AXx/0xIjYGREHI+LgiNuSNAEjhT8zj2bmscw8DnwLuLzlubszcy4z50ZtUlL3Rgp/RGxY9nA78GI37UialtUc6rsPuBr4eEQsAF8Fro6IzUAC88CXJtijpAnwuv2nuXPPPbe1fujQodb6W2+91Vq/5JJLTronTZbX7ZfUyvBLRRl+qSjDLxVl+KWiDL9UlJfuPs3deuutrfX169e31h977LEOu9EsceSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paI8zn+au+CCC8Zaf9hPenXqcuSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paI8zn+au+GGG8Za/5FHHumoE80aR36pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKmrocf6IOB/4LnAucBzYnZn/HhHrgPuBjcA8cFNm+uPvHlx55ZUDa8Om6FZdqxn53wf+MTMvAa4AvhwRnwZuB57IzAuBJ5rHkk4RQ8OfmUcy87nm/tvAYeA8YBuwt3naXuDGSTUpqXsn9Zk/IjYCW4CngHMy8wgs/YEA2ud9kjRTVn1uf0R8FHgA+Epm/i4iVrveTmDnaO1JmpRVjfwRcQZLwf9eZj7YLD4aERua+gZgcaV1M3N3Zs5l5lwXDUvqxtDwx9IQ/23gcGZ+Y1npYWBHc38H8FD37UmalNXs9m8F/h54ISIONcvuAO4CfhgRXwB+A3xuMi1qmO3btw+srVmzpnXd559/vrV+4MCBkXrS7Bsa/sz8b2DQB/zPdNuOpGnxDD+pKMMvFWX4paIMv1SU4ZeKMvxSUV66+xRw5plnttavv/76kV973759rfVjx46N/NqabY78UlGGXyrK8EtFGX6pKMMvFWX4paIMv1RUZOb0NhYxvY2dRs4444zW+pNPPjmwtri44gWW/uiWW25prb/77rutdc2ezFzVNfYc+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKI/zS6cZj/NLamX4paIMv1SU4ZeKMvxSUYZfKsrwS0UNDX9EnB8R/xURhyPiFxFxW7P8zoj4v4g41Pwb/eLxkqZu6Ek+EbEB2JCZz0XEx4BngRuBm4DfZ+a/rHpjnuQjTdxqT/IZOmNPZh4BjjT3346Iw8B547UnqW8n9Zk/IjYCW4CnmkW7IuLnEbEnIs4asM7OiDgYEQfH6lRSp1Z9bn9EfBR4EvhaZj4YEecAbwAJ/DNLHw3+YchruNsvTdhqd/tXFf6IOAP4EfDjzPzGCvWNwI8yc9OQ1zH80oR19sOeiAjg28Dh5cFvvgg8YTvw4sk2Kak/q/m2/0rgZ8ALwPFm8R3AzcBmlnb754EvNV8Otr2WI780YZ3u9nfF8EuT5+/5JbUy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFTX0Ap4dewP432WPP94sm0Wz2tus9gX2Nqoue7tgtU+c6u/5P7TxiIOZOddbAy1mtbdZ7QvsbVR99eZuv1SU4ZeK6jv8u3vefptZ7W1W+wJ7G1UvvfX6mV9Sf/oe+SX1pJfwR8R1EfGriHglIm7vo4dBImI+Il5oZh7udYqxZhq0xYh4cdmydRHx04j4dXO74jRpPfU2EzM3t8ws3et7N2szXk99tz8i1gAvA9cCC8AzwM2Z+cupNjJARMwDc5nZ+zHhiPgb4PfAd0/MhhQRdwNvZuZdzR/OszLzn2aktzs5yZmbJ9TboJmlb6XH967LGa+70MfIfznwSma+mpl/AH4AbOuhj5mXmQeANz+weBuwt7m/l6X/PFM3oLeZkJlHMvO55v7bwImZpXt971r66kUf4T8P+O2yxwvM1pTfCfwkIp6NiJ19N7OCc07MjNTcru+5nw8aOnPzNH1gZumZee9GmfG6a32Ef6XZRGbpkMPWzPxr4O+ALze7t1qdbwKfYmkatyPA1/tspplZ+gHgK5n5uz57WW6Fvnp53/oI/wJw/rLHnwBe66GPFWXma83tIrCfpY8ps+ToiUlSm9vFnvv5o8w8mpnHMvM48C16fO+amaUfAL6XmQ82i3t/71bqq6/3rY/wPwNcGBGfjIiPAJ8HHu6hjw+JiLXNFzFExFrgs8ze7MMPAzua+zuAh3rs5U/MyszNg2aWpuf3btZmvO7lJJ/mUMa/AWuAPZn5tak3sYKI+EuWRntY+sXj9/vsLSLuA65m6VdfR4GvAv8B/BD4C+A3wOcyc+pfvA3o7WpOcubmCfU2aGbpp+jxvetyxutO+vEMP6kmz/CTijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1TU/wNPnZK3k8+kHgAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "5YA3y997Z1no" + }, + "outputs": [], + "source": [ + "# Create TF Model.\n", + "class ConvNet(Model):\n", + " # Set layers.\n", + " def __init__(self):\n", + " super(ConvNet, self).__init__()\n", + " # Convolution Layer with 32 filters and a kernel size of 5.\n", + " self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)\n", + " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", + " self.maxpool1 = layers.MaxPool2D(2, strides=2)\n", + "\n", + " # Convolution Layer with 64 filters and a kernel size of 3.\n", + " self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)\n", + " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", + " self.maxpool2 = layers.MaxPool2D(2, strides=2)\n", + " \n", + " # Convolution Layer with 64 filters and a kernel size of 3.\n", + " self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)\n", + " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", + " self.maxpool2 = layers.MaxPool2D(2, strides=2)\n", + "\n", + " # Flatten the data to a 1-D vector for the fully connected layer.\n", + " self.flatten = layers.Flatten()\n", + "\n", + " # Fully connected layer.\n", + " self.fc1 = layers.Dense(1024)\n", + " # Apply Dropout (if is_training is False, dropout is not applied).\n", + " self.dropout = layers.Dropout(rate=0.5)\n", + "\n", + " # Output layer, class prediction.\n", + " self.out = layers.Dense(num_classes)\n", + "\n", + " # Set forward pass.\n", + " def call(self, x, is_training=False):\n", + " x = tf.reshape(x, [-1, 28, 28, 1])\n", + " x = self.conv1(x)\n", + " x = self.maxpool1(x)\n", + " x = self.conv2(x)\n", + " x = self.maxpool2(x)\n", + " x = self.flatten(x)\n", + " x = self.fc1(x)\n", + " x = self.dropout(x, training=is_training)\n", + " x = self.out(x)\n", + " if not is_training:\n", + " # tf cross entropy expect logits without softmax, so only\n", + " # apply softmax when not training.\n", + " x = tf.nn.softmax(x)\n", + " return x\n", + "\n", + "# Build neural network model.\n", + "conv_net = ConvNet()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model prediction: 1\n" - ] + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "Bh7K3VrlZ1nt" + }, + "outputs": [], + "source": [ + "# Cross-Entropy Loss.\n", + "# Note that this will apply 'softmax' to the logits.\n", + "def cross_entropy_loss(x, y):\n", + " # Convert labels to int 64 for tf cross-entropy function.\n", + " y = tf.cast(y, tf.int64)\n", + " # Apply softmax to logits and compute cross-entropy.\n", + " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n", + " # Average loss across the batch.\n", + " return tf.reduce_mean(loss)\n", + "\n", + "# Accuracy metric.\n", + "def accuracy(y_pred, y_true):\n", + " # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n", + " correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n", + " return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n", + "\n", + "# Stochastic gradient descent optimizer.\n", + "optimizer = tf.optimizers.Adam(learning_rate)" + ] }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADbVJREFUeJzt3W2IXPUVx/HfSWzfpH2hZE3jU9I2EitCTVljoRKtxZKUStIX0YhIiqUbJRoLfVFJwEaKINqmLRgSthi6BbUK0bqE0KaINBWCuJFaNVtblTVNs2yMEWsI0picvti7siY7/zuZuU+b8/2AzMOZuXO8+tt7Z/733r+5uwDEM6PuBgDUg/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwjqnCo/zMw4nBAombtbO6/rastvZkvN7A0ze9PM7u1mWQCqZZ0e229mMyX9U9INkg5IeknSLe6+L/EetvxAyarY8i+W9Ka7v+3u/5P0e0nLu1gegAp1E/4LJf170uMD2XOfYmZ9ZjZkZkNdfBaAgnXzg99Uuxan7da7e7+kfondfqBJutnyH5B08aTHF0k62F07AKrSTfhfknSpmX3RzD4raZWkwWLaAlC2jnf73f1jM7tL0p8kzZS0zd1fL6wzAKXqeKivow/jOz9QukoO8gEwfRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFSlU3SjerNmzUrWH3744WR9zZo1yfrevXuT9ZUrV7asvfPOO8n3olxs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqK5m6TWzEUkfSjoh6WN37815PbP0VmzBggXJ+vDwcFfLnzEjvf1Yt25dy9rmzZu7+mxMrd1Zeos4yOeb7n64gOUAqBC7/UBQ3YbfJe0ys71m1ldEQwCq0e1u/zfc/aCZnS/pz2b2D3ffPfkF2R8F/jAADdPVlt/dD2a3hyQ9I2nxFK/pd/fevB8DAVSr4/Cb2Swz+/zEfUnflvRaUY0BKFc3u/1zJD1jZhPLedzd/1hIVwBK13H43f1tSV8tsBd0qKenp2VtYGCgwk4wnTDUBwRF+IGgCD8QFOEHgiL8QFCEHwiKS3dPA6nTYiVpxYoVLWuLF5920GWllixZ0rKWdzrwK6+8kqzv3r07WUcaW34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKqrS3ef8Ydx6e6OnDhxIlk/efJkRZ2cLm+svpve8qbwvvnmm5P1vOnDz1btXrqbLT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4fwPs3LkzWV+2bFmyXuc4/3vvvZesHz16tGVt3rx5RbfzKTNnzix1+U3FOD+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCCr3uv1mtk3SdyUdcvcrsufOk/SkpPmSRiTd5O7vl9fm9Hbttdcm6wsXLkzW88bxyxzn37p1a7K+a9euZP2DDz5oWbv++uuT792wYUOynufOO+9sWduyZUtXyz4btLPl/62kpac8d6+k59z9UknPZY8BTCO54Xf33ZKOnPL0ckkD2f0BSa2njAHQSJ1+55/j7qOSlN2eX1xLAKpQ+lx9ZtYnqa/szwFwZjrd8o+Z2VxJym4PtXqhu/e7e6+793b4WQBK0Gn4ByWtzu6vlvRsMe0AqEpu+M3sCUl7JC00swNm9gNJD0q6wcz+JemG7DGAaYTz+Qswf/78ZH3Pnj3J+uzZs5P1bq6Nn3ft++3btyfr999/f7J+7NixZD0l73z+vPXW09OTrH/00Ucta/fdd1/yvY888kiyfvz48WS9TpzPDyCJ8ANBEX4gKMIPBEX4gaAIPxAUQ30FWLBgQbI+PDzc1fLzhvqef/75lrVVq1Yl33v48OGOeqrC3Xffnaxv2rQpWU+tt7zToC+77LJk/a233krW68RQH4Akwg8ERfiBoAg/EBThB4Ii/EBQhB8IqvTLeKF7Q0NDyfrtt9/estbkcfw8g4ODyfqtt96arF911VVFtnPWYcsPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzl+BvPPx81x99dUFdTK9mKVPS89br92s940bNybrt912W8fLbgq2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVO44v5ltk/RdSYfc/YrsuY2Sfijp3exl6919Z1lNNt0dd9yRrOddIx5Tu/HGG5P1RYsWJeup9Z733yRvnP9s0M6W/7eSlk7x/C/d/crsn7DBB6ar3PC7+25JRyroBUCFuvnOf5eZ/d3MtpnZuYV1BKASnYZ/i6QvS7pS0qikX7R6oZn1mdmQmaUvRAegUh2F393H3P2Eu5+U9BtJixOv7Xf3Xnfv7bRJAMXrKPxmNnfSw+9Jeq2YdgBUpZ2hvickXSdptpkdkPRTSdeZ2ZWSXNKIpDUl9gigBLnhd/dbpnj60RJ6mbbyxqMj6+npaVm7/PLLk+9dv3590e184t13303Wjx8/XtpnNwVH+AFBEX4gKMIPBEX4gaAIPxAU4QeC4tLdKNWGDRta1tauXVvqZ4+MjLSsrV69Ovne/fv3F9xN87DlB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOdHV3buTF+4eeHChRV1crp9+/a1rL3wwgsVdtJMbPmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjG+QtgZsn6jBnd/Y1dtmxZx+/t7+9P1i+44IKOly3l/7vVOT05l1RPY8sPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0HljvOb2cWSfifpC5JOSup391+b2XmSnpQ0X9KIpJvc/f3yWm2uLVu2JOsPPfRQV8vfsWNHst7NWHrZ4/BlLn/r1q2lLTuCdrb8H0v6sbt/RdLXJa01s8sl3SvpOXe/VNJz2WMA00Ru+N191N1fzu5/KGlY0oWSlksayF42IGlFWU0CKN4Zfec3s/mSFkl6UdIcdx+Vxv9ASDq/6OYAlKftY/vN7HOStkv6kbv/N+949knv65PU11l7AMrS1pbfzD6j8eA/5u5PZ0+PmdncrD5X0qGp3uvu/e7e6+69RTQMoBi54bfxTfyjkobdfdOk0qCkialOV0t6tvj2AJTF3D39ArNrJP1V0qsaH+qTpPUa/97/lKRLJO2XtNLdj+QsK/1h09S8efOS9T179iTrPT09yXqTT5vN621sbKxlbXh4OPnevr70t8XR0dFk/dixY8n62crd2/pOnvud391fkNRqYd86k6YANAdH+AFBEX4gKMIPBEX4gaAIPxAU4QeCyh3nL/TDztJx/jxLlixJ1lesSJ8Tdc899yTrTR7nX7duXcva5s2bi24Han+cny0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFOP80sHTp0mQ9dd573jTVg4ODyXreFN95l3Pbt29fy9r+/fuT70VnGOcHkET4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzg+cZRjnB5BE+IGgCD8QFOEHgiL8QFCEHwiK8ANB5YbfzC42s+fNbNjMXjeze7LnN5rZf8zsb9k/3ym/XQBFyT3Ix8zmSprr7i+b2ecl7ZW0QtJNko66+8/b/jAO8gFK1+5BPue0saBRSaPZ/Q/NbFjShd21B6BuZ/Sd38zmS1ok6cXsqbvM7O9mts3Mzm3xnj4zGzKzoa46BVCoto/tN7PPSfqLpAfc/WkzmyPpsCSX9DONfzW4PWcZ7PYDJWt3t7+t8JvZZyTtkPQnd980RX2+pB3ufkXOcgg/ULLCTuyx8cuzPippeHLwsx8CJ3xP0mtn2iSA+rTza/81kv4q6VVJE3NBr5d0i6QrNb7bPyJpTfbjYGpZbPmBkhW6218Uwg+Uj/P5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgsq9gGfBDkt6Z9Lj2dlzTdTU3pral0RvnSqyt3ntvrDS8/lP+3CzIXfvra2BhKb21tS+JHrrVF29sdsPBEX4gaDqDn9/zZ+f0tTemtqXRG+dqqW3Wr/zA6hP3Vt+ADWpJfxmttTM3jCzN83s3jp6aMXMRszs1Wzm4VqnGMumQTtkZq9Neu48M/uzmf0ru51ymrSaemvEzM2JmaVrXXdNm/G68t1+M5sp6Z+SbpB0QNJLkm5x932VNtKCmY1I6nX32seEzWyJpKOSfjcxG5KZPSTpiLs/mP3hPNfdf9KQ3jbqDGduLqm3VjNLf181rrsiZ7wuQh1b/sWS3nT3t939f5J+L2l5DX00nrvvlnTklKeXSxrI7g9o/H+eyrXorRHcfdTdX87ufyhpYmbpWtddoq9a1BH+CyX9e9LjA2rWlN8uaZeZ7TWzvrqbmcKciZmRstvza+7nVLkzN1fplJmlG7PuOpnxumh1hH+q2USaNOTwDXf/mqRlktZmu7dozxZJX9b4NG6jkn5RZzPZzNLbJf3I3f9bZy+TTdFXLeutjvAfkHTxpMcXSTpYQx9TcveD2e0hSc9o/GtKk4xNTJKa3R6quZ9PuPuYu59w95OSfqMa1102s/R2SY+5+9PZ07Wvu6n6qmu91RH+lyRdamZfNLPPSlolabCGPk5jZrOyH2JkZrMkfVvNm314UNLq7P5qSc/W2MunNGXm5lYzS6vmdde0Ga9rOcgnG8r4laSZkra5+wOVNzEFM/uSxrf20vgZj4/X2ZuZPSHpOo2f9TUm6aeS/iDpKUmXSNovaaW7V/7DW4vertMZztxcUm+tZpZ+UTWuuyJnvC6kH47wA2LiCD8gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0H9HwAENgeMtPBpAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "qiwbolEgZ1nw" + }, + "outputs": [], + "source": [ + "# Optimization process. \n", + "def run_optimization(x, y):\n", + " # Wrap computation inside a GradientTape for automatic differentiation.\n", + " with tf.GradientTape() as g:\n", + " # Forward pass.\n", + " pred = conv_net(x, is_training=True)\n", + " # Compute loss.\n", + " loss = cross_entropy_loss(pred, y)\n", + " \n", + " # Variables to update, i.e. trainable variables.\n", + " trainable_variables = conv_net.trainable_variables\n", + "\n", + " # Compute gradients.\n", + " gradients = g.gradient(loss, trainable_variables)\n", + " \n", + " # Update W and b following gradients.\n", + " optimizer.apply_gradients(zip(gradients, trainable_variables))" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model prediction: 0\n" - ] + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Fqpe_wWWZ1nz", + "outputId": "714097be-f518-418b-dcec-9924886be6e2" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 10, loss: 1.900374, accuracy: 0.703125\n", + "step: 20, loss: 1.607977, accuracy: 0.921875\n", + "step: 30, loss: 1.620427, accuracy: 0.898438\n", + "step: 40, loss: 1.579274, accuracy: 0.937500\n", + "step: 50, loss: 1.521082, accuracy: 0.976562\n", + "step: 60, loss: 1.525948, accuracy: 0.976562\n", + "step: 70, loss: 1.548241, accuracy: 0.945312\n", + "step: 80, loss: 1.521636, accuracy: 0.968750\n", + "step: 90, loss: 1.518036, accuracy: 0.960938\n", + "step: 100, loss: 1.533605, accuracy: 0.953125\n", + "step: 110, loss: 1.479001, accuracy: 1.000000\n", + "step: 120, loss: 1.516089, accuracy: 0.960938\n", + "step: 130, loss: 1.504505, accuracy: 0.992188\n", + "step: 140, loss: 1.503561, accuracy: 0.968750\n", + "step: 150, loss: 1.492262, accuracy: 0.976562\n", + "step: 160, loss: 1.475691, accuracy: 1.000000\n", + "step: 170, loss: 1.513242, accuracy: 0.968750\n", + "step: 180, loss: 1.481829, accuracy: 1.000000\n", + "step: 190, loss: 1.497677, accuracy: 0.976562\n", + "step: 200, loss: 1.508653, accuracy: 0.976562\n" + ] + } + ], + "source": [ + "# Run training for the given number of steps.\n", + "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n", + " # Run the optimization to update W and b values.\n", + " run_optimization(batch_x, batch_y)\n", + " \n", + " if step % display_step == 0:\n", + " pred = conv_net(batch_x)\n", + " loss = cross_entropy_loss(pred, batch_y)\n", + " acc = accuracy(pred, batch_y)\n", + " print(\"step: %i, loss: %f, accuracy: %f\" % (step, loss, acc))" + ] }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADXZJREFUeJzt3X+oXPWZx/HPZ00bMQ2SS0ga0uzeGmVdCW6qF1GUqhRjNlZi0UhCWLJaevtHhRb3jxUVKmpBZJvd/mMgxdAIbdqicQ219AcS1xUWyY2EmvZu2xiyTZqQH6ahiQSquU//uOfKNblzZjJzZs7c+7xfIDNznnNmHo753O85c2bm64gQgHz+pu4GANSD8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSGpWL1/MNh8nBLosItzKeh2N/LZX2v6t7X22H+nkuQD0ltv9bL/tSyT9TtIdkg5J2iVpXUT8pmQbRn6gy3ox8t8gaV9E7I+Iv0j6oaTVHTwfgB7qJPyLJR2c9PhQsexjbA/bHrE90sFrAahYJ2/4TXVoccFhfURslrRZ4rAf6CedjPyHJC2Z9Pgzkg531g6AXukk/LskXWX7s7Y/KWmtpB3VtAWg29o+7I+ID20/JOnnki6RtCUifl1ZZwC6qu1LfW29GOf8QNf15EM+AKYvwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Jqe4puSbJ9QNJpSeckfRgRQ1U0hY+77rrrSuvbt29vWBscHKy4m/6xYsWK0vro6GjD2sGDB6tuZ9rpKPyF2yPiRAXPA6CHOOwHkuo0/CHpF7Z32x6uoiEAvdHpYf/NEXHY9gJJv7T9fxHxxuQVij8K/GEA+kxHI39EHC5uj0l6WdINU6yzOSKGeDMQ6C9th9/2HNtzJ+5LWiFpb1WNAeiuTg77F0p62fbE8/wgIn5WSVcAuq7t8EfEfkn/WGEvaODOO+8src+ePbtHnfSXu+++u7T+4IMPNqytXbu26namHS71AUkRfiApwg8kRfiBpAg/kBThB5Kq4lt96NCsWeX/G1atWtWjTqaX3bt3l9YffvjhhrU5c+aUbvv++++31dN0wsgPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxnb8P3H777aX1m266qbT+7LPPVtnOtDFv3rzS+jXXXNOwdtlll5Vuy3V+ADMW4QeSIvxAUoQfSIrwA0kRfiApwg8k5Yjo3YvZvXuxPrJs2bLS+uuvv15af++990rr119/fcPamTNnSredzprtt1tuuaVhbdGiRaXbHj9+vJ2W+kJEuJX1GPmBpAg/kBThB5Ii/EBShB9IivADSRF+IKmm3+e3vUXSFyUdi4hlxbIBST+SNCjpgKT7I+JP3Wtzenv88cdL681+Q37lypWl9Zl6LX9gYKC0fuutt5bWx8bGqmxnxmll5P+epPP/9T0i6bWIuErSa8VjANNI0/BHxBuSTp63eLWkrcX9rZLuqbgvAF3W7jn/wog4IknF7YLqWgLQC13/DT/bw5KGu/06AC5OuyP/UduLJKm4PdZoxYjYHBFDETHU5msB6IJ2w79D0obi/gZJr1TTDoBeaRp+29sk/a+kv7d9yPaXJT0j6Q7bv5d0R/EYwDTS9Jw/ItY1KH2h4l6mrfvuu6+0vmrVqtL6vn37SusjIyMX3dNM8Nhjj5XWm13HL/u+/6lTp9ppaUbhE35AUoQfSIrwA0kRfiApwg8kRfiBpJiiuwJr1qwprTebDvq5556rsp1pY3BwsLS+fv360vq5c+dK608//XTD2gcffFC6bQaM/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFNf5W3T55Zc3rN14440dPfemTZs62n66Gh4u/3W3+fPnl9ZHR0dL6zt37rzonjJh5AeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpLjO36LZs2c3rC1evLh0223btlXdzoywdOnSjrbfu3dvRZ3kxMgPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0k1vc5ve4ukL0o6FhHLimVPSPqKpOPFao9GxE+71WQ/OH36dMPanj17Sre99tprS+sDAwOl9ZMnT5bW+9mCBQsa1ppNbd7Mm2++2dH22bUy8n9P0soplv9HRCwv/pvRwQdmoqbhj4g3JE3foQfAlDo553/I9q9sb7E9r7KOAPREu+HfJGmppOWSjkj6dqMVbQ/bHrE90uZrAeiCtsIfEUcj4lxEjEn6rqQbStbdHBFDETHUbpMAqtdW+G0vmvTwS5L4ehUwzbRyqW+bpNskzbd9SNI3Jd1me7mkkHRA0le72COALmga/ohYN8Xi57vQS187e/Zsw9q7775buu29995bWn/11VdL6xs3biytd9OyZctK61dccUVpfXBwsGEtItpp6SNjY2MdbZ8dn/ADkiL8QFKEH0iK8ANJEX4gKcIPJOVOL7dc1IvZvXuxHrr66qtL608++WRp/a677iqtl/1seLedOHGitN7s30/ZNNu22+ppwty5c0vrZZdnZ7KIaGnHMvIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5+8Dy5cvL61feeWVPerkQi+++GJH22/durVhbf369R0996xZzDA/Fa7zAyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcaG0DzSb4rtZvZ/t37+/a8/d7GfF9+5lLpkyjPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTT6/y2l0h6QdKnJY1J2hwR37E9IOlHkgYlHZB0f0T8qXutYjoq+23+Tn+3n+v4nWll5P9Q0r9GxD9IulHS12xfI+kRSa9FxFWSXiseA5gmmoY/Io5ExNvF/dOSRiUtlrRa0sTPtGyVdE+3mgRQvYs657c9KOlzkt6StDAijkjjfyAkLai6OQDd0/Jn+21/StJLkr4REX9u9XzN9rCk4fbaA9AtLY38tj+h8eB/PyK2F4uP2l5U1BdJOjbVthGxOSKGImKoioYBVKNp+D0+xD8vaTQiNk4q7ZC0obi/QdIr1bcHoFtaOey/WdI/S3rH9sR3Sx+V9IykH9v+sqQ/SFrTnRYxnZX9NHwvfzYeF2oa/oh4U1KjE/wvVNsOgF7hE35AUoQfSIrwA0kRfiApwg8kRfiBpPjpbnTVpZde2va2Z8+erbATnI+RH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jo/uuqBBx5oWDt16lTptk899VTV7WASRn4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrr/OiqXbt2Naxt3LixYU2Sdu7cWXU7mISRH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeScrM50m0vkfSCpE9LGpO0OSK+Y/sJSV+RdLxY9dGI+GmT52JCdqDLIsKtrNdK+BdJWhQRb9ueK2m3pHsk3S/pTET8e6tNEX6g+1oNf9NP+EXEEUlHivunbY9KWtxZewDqdlHn/LYHJX1O0lvFoods/8r2FtvzGmwzbHvE9khHnQKoVNPD/o9WtD8l6b8lfSsittteKOmEpJD0lMZPDR5s8hwc9gNdVtk5vyTZ/oSkn0j6eURc8G2M4ojgJxGxrMnzEH6gy1oNf9PDftuW9Lyk0cnBL94InPAlSXsvtkkA9Wnl3f5bJP2PpHc0fqlPkh6VtE7Sco0f9h+Q9NXizcGy52LkB7qs0sP+qhB+oPsqO+wHMDMRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkur1FN0nJP3/pMfzi2X9qF9769e+JHprV5W9/V2rK/b0+/wXvLg9EhFDtTVQol9769e+JHprV129cdgPJEX4gaTqDv/mml+/TL/21q99SfTWrlp6q/WcH0B96h75AdSklvDbXmn7t7b32X6kjh4asX3A9ju299Q9xVgxDdox23snLRuw/Uvbvy9up5wmrabenrD9x2Lf7bG9qqbeltjeaXvU9q9tf71YXuu+K+mrlv3W88N+25dI+p2kOyQdkrRL0rqI+E1PG2nA9gFJQxFR+zVh25+XdEbSCxOzIdl+VtLJiHim+MM5LyL+rU96e0IXOXNzl3prNLP0v6jGfVfljNdVqGPkv0HSvojYHxF/kfRDSatr6KPvRcQbkk6et3i1pK3F/a0a/8fTcw166wsRcSQi3i7un5Y0MbN0rfuupK9a1BH+xZIOTnp8SP015XdI+oXt3baH625mCgsnZkYqbhfU3M/5ms7c3EvnzSzdN/uunRmvq1ZH+KeaTaSfLjncHBHXSfonSV8rDm/Rmk2Slmp8Grcjkr5dZzPFzNIvSfpGRPy5zl4mm6KvWvZbHeE/JGnJpMefkXS4hj6mFBGHi9tjkl7W+GlKPzk6MUlqcXus5n4+EhFHI+JcRIxJ+q5q3HfFzNIvSfp+RGwvFte+76bqq679Vkf4d0m6yvZnbX9S0lpJO2ro4wK25xRvxMj2HEkr1H+zD++QtKG4v0HSKzX28jH9MnNzo5mlVfO+67cZr2v5kE9xKeM/JV0iaUtEfKvnTUzB9hUaH+2l8W88/qDO3mxvk3Sbxr/1dVTSNyX9l6QfS/pbSX+QtCYiev7GW4PebtNFztzcpd4azSz9lmrcd1XOeF1JP3zCD8iJT/gBSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0jqr8DO4JozFB6IAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5W0CXC9QZ1n0", + "outputId": "18f6dc26-f0c6-4e05-ce6c-c920a55b3fd0" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test Accuracy: 0.974100\n" + ] + } + ], + "source": [ + "# Test model on validation set.\n", + "pred = conv_net(x_test)\n", + "print(\"Test Accuracy: %f\" % accuracy(pred, y_test))" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model prediction: 4\n" - ] + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "AjtzFhVqZ1n1" + }, + "outputs": [], + "source": [ + "# Visualize predictions.\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "-veF_PVHZ1n2", + "outputId": "e594df75-2708-4ece-c4c1-ecaba3aab79f" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM3ElEQVR4nO3dXahc9bnH8d/vpCmI6UXiS9ik0bTBC8tBEo1BSCxbQktOvIjFIM1FyYHi7kWUFkuo2It4WaQv1JvALkrTkmMJpGoQscmJxVDU4o5Es2NIjCGaxLxYIjQRJMY+vdjLso0za8ZZa2ZN8nw/sJmZ9cya9bDMz7VmvczfESEAV77/aroBAINB2IEkCDuQBGEHkiDsQBJfGeTCbHPoH+iziHCr6ZW27LZX2j5o+7Dth6t8FoD+cq/n2W3PkHRI0nckHZf0mqS1EfFWyTxs2YE+68eWfamkwxFxJCIuSPqTpNUVPg9AH1UJ+zxJx6a9Pl5M+xzbY7YnbE9UWBaAivp+gC4ixiWNS+zGA02qsmU/IWn+tNdfL6YBGEJVwv6apJtsf8P2VyV9X9L2etoCULeed+Mj4qLtByT9RdIMSU9GxP7aOgNQq55PvfW0ML6zA33Xl4tqAFw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9Dw+uyTZPirpnKRPJV2MiCV1NAWgfpXCXrgrIv5Rw+cA6CN244EkqoY9JO2wvcf2WKs32B6zPWF7ouKyAFTgiOh9ZnteRJywfb2knZIejIjdJe/vfWEAuhIRbjW90pY9Ik4Uj2ckPS1paZXPA9A/PYfd9tW2v/bZc0nflTRZV2MA6lXlaPxcSU/b/uxz/i8iXqilKwC1q/Sd/UsvjO/sQN/15Ts7gMsHYQeSIOxAEoQdSIKwA0nUcSNMCmvWrGlbu//++0vnff/990vrH3/8cWl9y5YtpfVTp061rR0+fLh0XuTBlh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuCuty4dOXKkbW3BggWDa6SFc+fOta3t379/gJ0Ml+PHj7etPfbYY6XzTkxcvr+ixl1vQHKEHUiCsANJEHYgCcIOJEHYgSQIO5AE97N3qeye9VtuuaV03gMHDpTWb7755tL6rbfeWlofHR1tW7vjjjtK5z127Fhpff78+aX1Ki5evFha/+CDD0rrIyMjPS/7vffeK61fzufZ22HLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcD/7FWD27Nlta4sWLSqdd8+ePaX122+/vaeeutHp9/IPHTpUWu90/cKcOXPa1tavX18676ZNm0rrw6zn+9ltP2n7jO3JadPm2N5p++3isf2/NgBDoZvd+N9LWnnJtIcl7YqImyTtKl4DGGIdwx4RuyWdvWTyakmbi+ebJd1Tc18AatbrtfFzI+Jk8fyUpLnt3mh7TNJYj8sBUJPKN8JERJQdeIuIcUnjEgfogCb1eurttO0RSSoez9TXEoB+6DXs2yWtK56vk/RsPe0A6JeO59ltPyVpVNK1kk5L2ijpGUlbJd0g6V1J90XEpQfxWn0Wu/Ho2r333lta37p1a2l9cnKybe2uu+4qnffs2Y7/nIdWu/PsHb+zR8TaNqUVlToCMFBcLgskQdiBJAg7kARhB5Ig7EAS3OKKxlx//fWl9X379lWaf82aNW1r27ZtK533csaQzUByhB1IgrADSRB2IAnCDiRB2IEkCDuQBEM2ozGdfs75uuuuK61/+OGHpfWDBw9+6Z6uZGzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJ7mdHXy1btqxt7cUXXyydd+bMmaX10dHR0vru3btL61cq7mcHkiPsQBKEHUiCsANJEHYgCcIOJEHYgSS4nx19tWrVqra1TufRd+3aVVp/5ZVXeuopq45bdttP2j5je3LatEdtn7C9t/hr/18UwFDoZjf+95JWtpj+m4hYVPw9X29bAOrWMewRsVvS2QH0AqCPqhyge8D2m8Vu/ux2b7I9ZnvC9kSFZQGoqNewb5K0UNIiSScl/ardGyNiPCKWRMSSHpcFoAY9hT0iTkfEpxHxL0m/k7S03rYA1K2nsNsemfbye5Im270XwHDoeJ7d9lOSRiVda/u4pI2SRm0vkhSSjkr6UR97xBC76qqrSusrV7Y6kTPlwoULpfNu3LixtP7JJ5+U1vF5HcMeEWtbTH6iD70A6CMulwWSIOxAEoQdSIKwA0kQdiAJbnFFJRs2bCitL168uG3thRdeKJ335Zdf7qkntMaWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYMhmlLr77rtL688880xp/aOPPmpbK7v9VZJeffXV0jpaY8hmIDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiC+9mTu+aaa0rrjz/+eGl9xowZpfXnn28/5ifn0QeLLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH97Fe4TufBO53rvu2220rr77zzTmm97J71TvOiNz3fz257vu2/2n7L9n7bPy6mz7G90/bbxePsupsGUJ9uduMvSvppRHxL0h2S1tv+lqSHJe2KiJsk7SpeAxhSHcMeEScj4vXi+TlJByTNk7Ra0ubibZsl3dOvJgFU96Wujbe9QNJiSX+XNDciThalU5LmtplnTNJY7y0CqEPXR+Ntz5K0TdJPIuKf02sxdZSv5cG3iBiPiCURsaRSpwAq6SrstmdqKuhbIuLPxeTTtkeK+oikM/1pEUAdOu7G27akJyQdiIhfTyttl7RO0i+Kx2f70iEqWbhwYWm906m1Th566KHSOqfXhkc339mXSfqBpH229xbTHtFUyLfa/qGkdyXd158WAdShY9gj4m+SWp6kl7Si3nYA9AuXywJJEHYgCcIOJEHYgSQIO5AEPyV9Bbjxxhvb1nbs2FHpszds2FBaf+655yp9PgaHLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59ivA2Fj7X/264YYbKn32Sy+9VFof5E+Roxq27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZLwPLly8vrT/44IMD6gSXM7bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEN+Ozz5f0B0lzJYWk8Yj4re1HJd0v6YPirY9ExPP9ajSzO++8s7Q+a9asnj+70/jp58+f7/mzMVy6uajmoqSfRsTrtr8maY/tnUXtNxHxy/61B6Au3YzPflLSyeL5OdsHJM3rd2MA6vWlvrPbXiBpsaS/F5MesP2m7Sdtz24zz5jtCdsTlToFUEnXYbc9S9I2ST+JiH9K2iRpoaRFmtry/6rVfBExHhFLImJJDf0C6FFXYbc9U1NB3xIRf5akiDgdEZ9GxL8k/U7S0v61CaCqjmG3bUlPSDoQEb+eNn1k2tu+J2my/vYA1KWbo/HLJP1A0j7be4tpj0haa3uRpk7HHZX0o750iEreeOON0vqKFStK62fPnq2zHTSom6Pxf5PkFiXOqQOXEa6gA5Ig7EAShB1IgrADSRB2IAnCDiThQQ65a5vxfYE+i4hWp8rZsgNZEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoMesvkfkt6d9vraYtowGtbehrUvid56VWdvN7YrDPSimi8s3J4Y1t+mG9behrUvid56Naje2I0HkiDsQBJNh3284eWXGdbehrUvid56NZDeGv3ODmBwmt6yAxgQwg4k0UjYba+0fdD2YdsPN9FDO7aP2t5ne2/T49MVY+idsT05bdoc2zttv108thxjr6HeHrV9olh3e22vaqi3+bb/avst2/tt/7iY3ui6K+lrIOtt4N/Zbc+QdEjSdyQdl/SapLUR8dZAG2nD9lFJSyKi8QswbH9b0nlJf4iI/y6mPSbpbET8ovgf5eyI+NmQ9PaopPNND+NdjFY0Mn2YcUn3SPpfNbjuSvq6TwNYb01s2ZdKOhwRRyLigqQ/SVrdQB9DLyJ2S7p0SJbVkjYXzzdr6h/LwLXpbShExMmIeL14fk7SZ8OMN7ruSvoaiCbCPk/SsWmvj2u4xnsPSTts77E91nQzLcyNiJPF81OS5jbZTAsdh/EepEuGGR+addfL8OdVcYDui5ZHxK2S/kfS+mJ3dSjF1HewYTp32tUw3oPSYpjx/2hy3fU6/HlVTYT9hKT5015/vZg2FCLiRPF4RtLTGr6hqE9/NoJu8Xim4X7+Y5iG8W41zLiGYN01Ofx5E2F/TdJNtr9h+6uSvi9pewN9fIHtq4sDJ7J9taTvaviGot4uaV3xfJ2kZxvs5XOGZRjvdsOMq+F11/jw5xEx8D9JqzR1RP4dST9vooc2fX1T0hvF3/6me5P0lKZ26z7R1LGNH0q6RtIuSW9L+n9Jc4aotz9K2ifpTU0Fa6Sh3pZrahf9TUl7i79VTa+7kr4Gst64XBZIggN0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DEvwEvYRv57rmVLgAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model prediction: 7\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANYElEQVR4nO3df4hd9ZnH8c9n3QTEFk0iOwxG1hr1j7iolVEWVxaX2uiKJgakJshiqTD9o0LF+CNkhQiLKLvb3T8DUxoatWvTkJjGumzqhvpjwQRHiTHRtBpJbMIkQzZgE0Rqkmf/mDPLVOeeOznn3ntu8rxfMNx7z3PvOQ9XPzm/7jlfR4QAnPv+rOkGAPQGYQeSIOxAEoQdSIKwA0n8eS8XZptD/0CXRYSnm15rzW77dtu/tf2R7ZV15gWgu1z1PLvt8yT9TtK3JR2U9Jak5RHxfslnWLMDXdaNNfuNkj6KiI8j4o+Sfi5pSY35AeiiOmG/RNLvp7w+WEz7E7aHbY/aHq2xLAA1df0AXUSMSBqR2IwHmlRnzX5I0qVTXs8vpgHoQ3XC/pakK21/w/ZsScskbelMWwA6rfJmfESctP2gpK2SzpO0NiL2dKwzAB1V+dRbpYWxzw50XVd+VAPg7EHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASPb2VNKp55JFHSuvnn39+y9o111xT+tl77rmnUk+T1qxZU1p/8803W9aee+65WsvGmWHNDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcHfZPrB+/frSet1z4U3at29fy9qtt95a+tlPPvmk0+2kwN1lgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJrmfvgSbPo+/du7e0vnXr1tL65ZdfXlq/6667SusLFixoWbvvvvtKP/v000+X1nFmaoXd9n5JxyWdknQyIoY60RSAzuvEmv3vIuJoB+YDoIvYZweSqBv2kPRr22/bHp7uDbaHbY/aHq25LAA11N2MvzkiDtn+C0mv2N4bEa9PfUNEjEgakbgQBmhSrTV7RBwqHsclvSjpxk40BaDzKofd9gW2vz75XNIiSbs71RiAzqqzGT8g6UXbk/P5j4j4r450dZYZGio/47h06dJa89+zZ09pffHixS1rR4+Wnyg5ceJEaX327Nml9e3bt5fWr7322pa1efPmlX4WnVU57BHxsaTW/yUB9BVOvQFJEHYgCcIOJEHYgSQIO5AEl7h2wODgYGm9OD3ZUrtTa7fddltpfWxsrLRex4oVK0rrCxcurDzvl19+ufJnceZYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEpxn74CXXnqptH7FFVeU1o8fP15aP3bs2Bn31CnLli0rrc+aNatHnaAu1uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATn2XvgwIEDTbfQ0qOPPlpav+qqq2rNf8eOHZVq6DzW7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCOidwuze7cwSJLuvPPO0vqGDRtK6+2GbB4fHy+tl10P/9prr5V+FtVExLQDFbRds9tea3vc9u4p0+bafsX2h8XjnE42C6DzZrIZ/1NJt39p2kpJ2yLiSknbitcA+ljbsEfE65K+fF+kJZLWFc/XSbq7w30B6LCqv40fiIjJAcYOSxpo9Ubbw5KGKy4HQIfUvhAmIqLswFtEjEgakThABzSp6qm3I7YHJal4LD8kC6BxVcO+RdL9xfP7Jf2yM+0A6Ja2m/G2X5B0i6SLbR+UtFrSM5J+YfsBSQckfaebTaK6oaGh0nq78+jtrF+/vrTOufT+0TbsEbG8RelbHe4FQBfxc1kgCcIOJEHYgSQIO5AEYQeS4FbS54DNmze3rC1atKjWvJ999tnS+hNPPFFr/ugd1uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kAS3kj4LDA4OltbffffdlrV58+aVfvbo0aOl9Ztuuqm0vm/fvtI6eq/yraQBnBsIO5AEYQeSIOxAEoQdSIKwA0kQdiAJrmc/C2zcuLG03u5cepnnn3++tM559HMHa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSILz7H1g8eLFpfXrr7++8rxfffXV0vrq1asrzxtnl7ZrdttrbY/b3j1l2pO2D9neWfzd0d02AdQ1k834n0q6fZrp/x4R1xV//9nZtgB0WtuwR8Trko71oBcAXVTnAN2DtncVm/lzWr3J9rDtUdujNZYFoKaqYV8jaYGk6ySNSfpRqzdGxEhEDEXEUMVlAeiASmGPiCMRcSoiTkv6saQbO9sWgE6rFHbbU+9tvFTS7lbvBdAf2p5nt/2CpFskXWz7oKTVkm6xfZ2kkLRf0ve72ONZr9315qtWrSqtz5o1q/Kyd+7cWVo/ceJE5Xnj7NI27BGxfJrJP+lCLwC6iJ/LAkkQdiAJwg4kQdiBJAg7kASXuPbAihUrSus33HBDrflv3ry5ZY1LWDGJNTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJOGI6N3C7N4trI98/vnnpfU6l7BK0vz581vWxsbGas0bZ5+I8HTTWbMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJcz34OmDt3bsvaF1980cNOvurTTz9tWWvXW7vfH1x44YWVepKkiy66qLT+8MMPV573TJw6dapl7fHHHy/97GeffVZpmazZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJzrOfA3bt2tV0Cy1t2LChZa3dtfYDAwOl9XvvvbdST/3u8OHDpfWnnnqq0nzbrtltX2r7N7bft73H9g+L6XNtv2L7w+JxTqUOAPTETDbjT0paERELJf21pB/YXihppaRtEXGlpG3FawB9qm3YI2IsIt4pnh+X9IGkSyQtkbSueNs6SXd3q0kA9Z3RPrvtyyR9U9IOSQMRMbnTdVjStDtYtoclDVdvEUAnzPhovO2vSdoo6aGI+MPUWkzctXLam0lGxEhEDEXEUK1OAdQyo7DbnqWJoP8sIjYVk4/YHizqg5LGu9MigE5oeytp29bEPvmxiHhoyvR/kfS/EfGM7ZWS5kbEY23mlfJW0ps2bSqtL1mypEed5HLy5MmWtdOnT9ea95YtW0rro6Ojlef9xhtvlNa3b99eWm91K+mZ7LP/jaR/kPSe7Z3FtFWSnpH0C9sPSDog6TszmBeAhrQNe0T8j6Rp/6WQ9K3OtgOgW/i5LJAEYQeSIOxAEoQdSIKwA0kwZHMfeOyx0p8n1B7SuczVV19dWu/mZaRr164tre/fv7/W/Ddu3Niytnfv3lrz7mcM2QwkR9iBJAg7kARhB5Ig7EAShB1IgrADSXCeHTjHcJ4dSI6wA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmgbdtuX2v6N7fdt77H9w2L6k7YP2d5Z/N3R/XYBVNX25hW2ByUNRsQ7tr8u6W1Jd2tiPPYTEfGvM14YN68Auq7VzStmMj77mKSx4vlx2x9IuqSz7QHotjPaZ7d9maRvStpRTHrQ9i7ba23PafGZYdujtkdrdQqglhnfg8721yS9JumpiNhke0DSUUkh6Z80san/vTbzYDMe6LJWm/EzCrvtWZJ+JWlrRPzbNPXLJP0qIv6qzXwIO9BllW84aduSfiLpg6lBLw7cTVoqaXfdJgF0z0yOxt8s6Q1J70k6XUxeJWm5pOs0sRm/X9L3i4N5ZfNizQ50Wa3N+E4h7ED3cd94IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEm1vONlhRyUdmPL64mJaP+rX3vq1L4nequpkb3/ZqtDT69m/snB7NCKGGmugRL/21q99SfRWVa96YzMeSIKwA0k0HfaRhpdfpl9769e+JHqrqie9NbrPDqB3ml6zA+gRwg4k0UjYbd9u+7e2P7K9sokeWrG93/Z7xTDUjY5PV4yhN25795Rpc22/YvvD4nHaMfYa6q0vhvEuGWa80e+u6eHPe77Pbvs8Sb+T9G1JByW9JWl5RLzf00ZasL1f0lBENP4DDNt/K+mEpGcnh9ay/c+SjkXEM8U/lHMi4vE+6e1JneEw3l3qrdUw499Vg99dJ4c/r6KJNfuNkj6KiI8j4o+Sfi5pSQN99L2IeF3SsS9NXiJpXfF8nSb+Z+m5Fr31hYgYi4h3iufHJU0OM97od1fSV080EfZLJP1+yuuD6q/x3kPSr22/bXu46WamMTBlmK3DkgaabGYabYfx7qUvDTPeN99dleHP6+IA3VfdHBHXS/p7ST8oNlf7Ukzsg/XTudM1khZoYgzAMUk/arKZYpjxjZIeiog/TK01+d1N01dPvrcmwn5I0qVTXs8vpvWFiDhUPI5LelETux395MjkCLrF43jD/fy/iDgSEaci4rSkH6vB764YZnyjpJ9FxKZicuPf3XR99ep7ayLsb0m60vY3bM+WtEzSlgb6+ArbFxQHTmT7AkmL1H9DUW+RdH/x/H5Jv2ywlz/RL8N4txpmXA1/d40Pfx4RPf+TdIcmjsjvk/SPTfTQoq/LJb1b/O1pujdJL2his+4LTRzbeEDSPEnbJH0o6b8lze2j3p7TxNDeuzQRrMGGertZE5vouyTtLP7uaPq7K+mrJ98bP5cFkuAAHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X98jzceoKWtgAAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model prediction: 2\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAMEElEQVR4nO3dXYhc5R3H8d+vabwwepFUE4OKsRJRUUzKIoKhWnzBBiHmRoxQEiqsFwYi9KJiLxRKQaTaCy+EFcU0WF+IBqPWaBrEtDeaVVNNfIlWIiasWSWCb4g1+fdiT8oad85s5pwzZ9z/9wPLzDzPnDl/DvnlOXNe5nFECMDM95O2CwDQH4QdSIKwA0kQdiAJwg4k8dN+rsw2h/6BhkWEp2qvNLLbvtr2u7bft31rlc8C0Cz3ep7d9ixJeyRdKWmfpB2SVkXEWyXLMLIDDWtiZL9I0vsR8UFEfCvpUUkrKnwegAZVCfupkj6a9Hpf0fY9todtj9oerbAuABU1foAuIkYkjUjsxgNtqjKy75d0+qTXpxVtAAZQlbDvkLTY9pm2j5N0vaTN9ZQFoG4978ZHxHe210p6XtIsSQ9GxO7aKgNQq55PvfW0Mr6zA41r5KIaAD8ehB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dcpm5HP2Wef3bHvnXfeKV123bp1pf333ntvTzVlxcgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lwnh2NWrp0ace+w4cPly67b9++ustJrVLYbe+V9IWkQ5K+i4ihOooCUL86RvZfRcSnNXwOgAbxnR1IomrYQ9ILtl+1PTzVG2wP2x61PVpxXQAqqLobvywi9tueL2mr7XciYvvkN0TEiKQRSbIdFdcHoEeVRvaI2F88jkvaJOmiOooCUL+ew257ju0TjzyXdJWkXXUVBqBeVXbjF0jaZPvI5/wtIrbUUhVmjCVLlnTs++qrr0qX3bRpU93lpNZz2CPiA0kX1lgLgAZx6g1IgrADSRB2IAnCDiRB2IEkuMUVlZx//vml/WvXru3Yt2HDhrrLQQlGdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsqOScc84p7Z8zZ07Hvscee6zuclCCkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHknBE/yZpYUaYmeeVV14p7T/55JM79nW7F77bT01jahHhqdoZ2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCe5nR6lFixaV9g8NDZX279mzp2Mf59H7q+vIbvtB2+O2d01qm2d7q+33ise5zZYJoKrp7MY/JOnqo9pulbQtIhZL2la8BjDAuoY9IrZLOnhU8wpJ64vn6yVdW3NdAGrW63f2BRExVjz/WNKCTm+0PSxpuMf1AKhJ5QN0ERFlN7hExIikEYkbYYA29Xrq7YDthZJUPI7XVxKAJvQa9s2SVhfPV0t6qp5yADSl62687UckXSbpJNv7JN0u6U5Jj9u+UdKHkq5rski059JLL620/CeffFJTJaiqa9gjYlWHrstrrgVAg7hcFkiCsANJEHYgCcIOJEHYgSS4xRWlLrjggkrL33XXXTVVgqoY2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCaZsTu7iiy8u7X/22WdL+/fu3Vvaf8kll3Ts++abb0qXRW+YshlIjrADSRB2IAnCDiRB2IEkCDuQBGEHkuB+9uSuuOKK0v558+aV9m/ZsqW0n3Ppg4ORHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dx7chdeeGFpf7ffO9i4cWOd5aBBXUd22w/aHre9a1LbHbb3295Z/C1vtkwAVU1nN/4hSVdP0f6XiFhS/P293rIA1K1r2CNiu6SDfagFQIOqHKBba/uNYjd/bqc32R62PWp7tMK6AFTUa9jvk3SWpCWSxiTd3emNETESEUMRMdTjugDUoKewR8SBiDgUEYcl3S/ponrLAlC3nsJue+Gklysl7er0XgCDoevvxtt+RNJlkk6SdEDS7cXrJZJC0l5JN0XEWNeV8bvxfXfKKaeU9u/cubO0/7PPPivtP/fcc4+5JjSr0+/Gd72oJiJWTdH8QOWKAPQVl8sCSRB2IAnCDiRB2IEkCDuQBLe4znBr1qwp7Z8/f35p/3PPPVdjNWgTIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59hnujDPOqLR8t1tc8ePByA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXCefYa75pprKi3/9NNP11QJ2sbIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ59Bli2bFnHvm5TNiOPriO77dNtv2j7Ldu7ba8r2ufZ3mr7veJxbvPlAujVdHbjv5P0u4g4T9LFkm62fZ6kWyVti4jFkrYVrwEMqK5hj4ixiHiteP6FpLclnSpphaT1xdvWS7q2qSIBVHdM39ltL5K0VNLLkhZExFjR9bGkBR2WGZY03HuJAOow7aPxtk+Q9ISkWyLi88l9ERGSYqrlImIkIoYiYqhSpQAqmVbYbc/WRNAfjogni+YDthcW/QsljTdTIoA6dN2Nt21JD0h6OyLumdS1WdJqSXcWj081UiG6WrlyZce+WbNmlS77+uuvl/Zv3769p5oweKbznf0SSb+R9KbtnUXbbZoI+eO2b5T0oaTrmikRQB26hj0i/iXJHbovr7ccAE3hclkgCcIOJEHYgSQIO5AEYQeS4BbXH4Hjjz++tH/58uU9f/bGjRtL+w8dOtTzZ2OwMLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKe+JGZPq3M7t/KZpDZs2eX9r/00ksd+8bHy39T5IYbbijt//rrr0v7MXgiYsq7VBnZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJzrMDMwzn2YHkCDuQBGEHkiDsQBKEHUiCsANJEHYgia5ht3267Rdtv2V7t+11Rfsdtvfb3ln89f7j5QAa1/WiGtsLJS2MiNdsnyjpVUnXamI+9i8j4s/TXhkX1QCN63RRzXTmZx+TNFY8/8L225JOrbc8AE07pu/sthdJWirp5aJpre03bD9oe26HZYZtj9oerVQpgEqmfW287RMkvSTpTxHxpO0Fkj6VFJL+qIld/d92+Qx244GGddqNn1bYbc+W9Iyk5yPinin6F0l6JiLO7/I5hB1oWM83wti2pAckvT056MWBuyNWStpVtUgAzZnO0fhlkv4p6U1Jh4vm2yStkrREE7vxeyXdVBzMK/ssRnagYZV24+tC2IHmcT87kBxhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgia4/OFmzTyV9OOn1SUXbIBrU2ga1LonaelVnbWd06ujr/ew/WLk9GhFDrRVQYlBrG9S6JGrrVb9qYzceSIKwA0m0HfaRltdfZlBrG9S6JGrrVV9qa/U7O4D+aXtkB9AnhB1IopWw277a9ru237d9axs1dGJ7r+03i2moW52frphDb9z2rklt82xvtf1e8TjlHHst1TYQ03iXTDPe6rZre/rzvn9ntz1L0h5JV0raJ2mHpFUR8VZfC+nA9l5JQxHR+gUYtn8p6UtJfz0ytZbtuyQdjIg7i/8o50bE7wektjt0jNN4N1Rbp2nG16jFbVfn9Oe9aGNkv0jS+xHxQUR8K+lRSStaqGPgRcR2SQePal4haX3xfL0m/rH0XYfaBkJEjEXEa8XzLyQdmWa81W1XUldftBH2UyV9NOn1Pg3WfO8h6QXbr9oebruYKSyYNM3Wx5IWtFnMFLpO491PR00zPjDbrpfpz6viAN0PLYuIX0j6taSbi93VgRQT38EG6dzpfZLO0sQcgGOS7m6zmGKa8Sck3RIRn0/ua3PbTVFXX7ZbG2HfL+n0Sa9PK9oGQkTsLx7HJW3SxNeOQXLgyAy6xeN4y/X8X0QciIhDEXFY0v1qcdsV04w/IenhiHiyaG59201VV7+2Wxth3yFpse0zbR8n6XpJm1uo4wdszykOnMj2HElXafCmot4saXXxfLWkp1qs5XsGZRrvTtOMq+Vt1/r05xHR9z9JyzVxRP4/kv7QRg0d6vq5pH8Xf7vbrk3SI5rYrfuvJo5t3CjpZ5K2SXpP0j8kzRug2jZoYmrvNzQRrIUt1bZME7vob0jaWfwtb3vbldTVl+3G5bJAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/gciQMnFg+KOfAAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model prediction: 1\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANsklEQVR4nO3df4hV95nH8c+jbf+x/UPrrJg01bYGgyxsXIwpNJhsSosGgvaPNEoILimMCSYaWNiKQmoohZBss/9ElCkNnS1tSsFkO4hsTUXWDUjJGPLDzGybH6hVJmOMkEYk1OjTP+4xjDrneyb3nHPPGZ/3C4Z773nuPffJST45597vPedr7i4A174ZTTcAoDcIOxAEYQeCIOxAEIQdCOJzvXwzM+Orf6Bm7m6TLS+1ZzezlWb2JzN728y2lFkXgHpZt+PsZjZT0p8lfUfSCUkvS1rn7iOJ17BnB2pWx559uaS33f1dd/+bpN9IWl1ifQBqVCbs10v6y4THJ7JllzGzfjMbNrPhEu8FoKTav6Bz9wFJAxKH8UCTyuzZT0q6YcLjr2TLALRQmbC/LOlGM/uamX1B0lpJQ9W0BaBqXR/Gu/snZvawpN9LminpWXd/s7LOAFSq66G3rt6Mz+xA7Wr5UQ2A6YOwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSB6OmUzem/WrFnJ+lNPPZWsb9iwIVk/fPhwsn7PPffk1o4dO5Z8LarFnh0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmAW12vcokWLkvXR0dFS658xI72/2LRpU25tx44dpd4bk8ubxbXUj2rM7KikjyRdkPSJuy8rsz4A9aniF3T/4u6nK1gPgBrxmR0IomzYXdI+MztsZv2TPcHM+s1s2MyGS74XgBLKHsbf5u4nzewfJL1oZv/v7gcnPsHdByQNSHxBBzSp1J7d3U9mt6ckvSBpeRVNAahe12E3s1lm9qVL9yV9V9KRqhoDUK0yh/HzJL1gZpfW82t3/59KusJn0tfXl1sbHBzsYSdos67D7u7vSvqnCnsBUCOG3oAgCDsQBGEHgiDsQBCEHQiCS0lPA6nTRCVpzZo1ubXly5v9ndOKFStya0Wnx7722mvJ+sGDB5N1XI49OxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwaWkp4ELFy4k6xcvXuxRJ1crGisv01vRlM733ntvsl40nfS1Ku9S0uzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIxtlbYO/evcn6qlWrkvUmx9k/+OCDZP3s2bO5tQULFlTdzmVmzpxZ6/rbinF2IDjCDgRB2IEgCDsQBGEHgiDsQBCEHQiC68b3wO23356sL168OFkvGkevc5x9165dyfq+ffuS9Q8//DC3dueddyZfu23btmS9yEMPPZRb27lzZ6l1T0eFe3Yze9bMTpnZkQnL5pjZi2b2VnY7u942AZQ1lcP4X0haecWyLZL2u/uNkvZnjwG0WGHY3f2gpDNXLF4taTC7Pygpf/4hAK3Q7Wf2ee4+lt1/T9K8vCeaWb+k/i7fB0BFSn9B5+6eOsHF3QckDUicCAM0qduht3Ezmy9J2e2p6loCUIduwz4kaX12f72k31XTDoC6FJ7PbmbPSbpD0lxJ45J+JOm/Jf1W0lclHZP0fXe/8ku8ydZ1TR7GL1y4MFk/dOhQsj537txkvcy12Yuuvb579+5k/fHHH0/Wz507l6ynFJ3PXrTd+vr6kvWPP/44t/bYY48lX/vMM88k6+fPn0/Wm5R3PnvhZ3Z3X5dT+napjgD0FD+XBYIg7EAQhB0IgrADQRB2IAguJV2BRYsWJeujo6Ol1l809HbgwIHc2tq1a5OvPX36dFc99cIjjzySrD/99NPJemq7FZ0WfNNNNyXr77zzTrLeJC4lDQRH2IEgCDsQBGEHgiDsQBCEHQiCsANBcCnpaWB4eDhZf+CBB3JrbR5HLzI0NJSs33fffcn6LbfcUmU70x57diAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2Hig6H73IrbfeWlEn04vZpKdlf6pou5bZ7tu3b0/W77///q7X3RT27EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBOPsFXjwwQeT9aJrlGNyd999d7K+dOnSZD213Yv+nRSNs09HhXt2M3vWzE6Z2ZEJy7ab2UkzezX7u6veNgGUNZXD+F9IWjnJ8v9095uzv73VtgWgaoVhd/eDks70oBcANSrzBd3DZvZ6dpg/O+9JZtZvZsNmlr6QGoBadRv2nZK+IelmSWOSfpr3RHcfcPdl7r6sy/cCUIGuwu7u4+5+wd0vSvqZpOXVtgWgal2F3czmT3j4PUlH8p4LoB0Kx9nN7DlJd0iaa2YnJP1I0h1mdrMkl3RU0oYae2y9ovHgyPr6+nJrS5YsSb5269atVbfzqffffz9ZP3/+fG3v3ZTCsLv7ukkW/7yGXgDUiJ/LAkEQdiAIwg4EQdiBIAg7EASnuKJW27Zty61t3Lix1vc+evRobm39+vXJ1x4/frzibprHnh0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmCcHaXs3Zu+1ujixYt71MnVRkZGcmsvvfRSDztpB/bsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAE4+wVMLNkfcaMcv9PXbVqVdevHRgYSNavu+66rtctFf+zNTldNZf4vhx7diAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2CuzcuTNZf/LJJ0utf8+ePcl6mbHsusfB61z/rl27alv3tahwz25mN5jZATMbMbM3zWxztnyOmb1oZm9lt7PrbxdAt6ZyGP+JpH9z9yWSvilpo5ktkbRF0n53v1HS/uwxgJYqDLu7j7n7K9n9jySNSrpe0mpJg9nTBiWtqatJAOV9ps/sZrZQ0lJJf5Q0z93HstJ7kublvKZfUn/3LQKowpS/jTezL0raLelRd//rxJq7uySf7HXuPuDuy9x9WalOAZQypbCb2efVCfqv3P35bPG4mc3P6vMlnaqnRQBVsM5OOfGEzvmbg5LOuPujE5Y/JekDd3/CzLZImuPu/16wrvSbTVMLFixI1g8dOpSs9/X1JettPo20qLfx8fHc2ujoaPK1/f3pT39jY2PJ+rlz55L1a5W7T3rO9VQ+s39L0v2S3jCzV7NlWyU9Iem3ZvYDScckfb+KRgHUozDs7v6SpLyrM3y72nYA1IWfywJBEHYgCMIOBEHYgSAIOxBE4Th7pW92jY6zF1mxYkWyvmZN+rSCzZs3J+ttHmfftGlTbm3Hjh1VtwPlj7OzZweCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnnwZWrlyZrKfO+y6atnhoaChZL5ryuWi66pGRkdza8ePHk69FdxhnB4Ij7EAQhB0IgrADQRB2IAjCDgRB2IEgGGcHrjGMswPBEXYgCMIOBEHYgSAIOxAEYQeCIOxAEIVhN7MbzOyAmY2Y2Ztmtjlbvt3MTprZq9nfXfW3C6BbhT+qMbP5kua7+ytm9iVJhyWtUWc+9rPu/h9TfjN+VAPULu9HNVOZn31M0lh2/yMzG5V0fbXtAajbZ/rMbmYLJS2V9Mds0cNm9rqZPWtms3Ne029mw2Y2XKpTAKVM+bfxZvZFSf8r6Sfu/ryZzZN0WpJL+rE6h/oPFKyDw3igZnmH8VMKu5l9XtIeSb9396cnqS+UtMfd/7FgPYQdqFnXJ8JY5/KhP5c0OjHo2Rd3l3xP0pGyTQKoz1S+jb9N0v9JekPSpbmBt0paJ+lmdQ7jj0rakH2Zl1oXe3agZqUO46tC2IH6cT47EBxhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgiMILTlbstKRjEx7PzZa1UVt7a2tfEr11q8reFuQVeno++1Vvbjbs7ssaayChrb21tS+J3rrVq944jAeCIOxAEE2HfaDh909pa29t7Uuit271pLdGP7MD6J2m9+wAeoSwA0E0EnYzW2lmfzKzt81sSxM95DGzo2b2RjYNdaPz02Vz6J0ysyMTls0xsxfN7K3sdtI59hrqrRXTeCemGW902zU9/XnPP7Ob2UxJf5b0HUknJL0saZ27j/S0kRxmdlTSMndv/AcYZrZC0llJ/3Vpai0ze1LSGXd/Ivsf5Wx3/2FLetuuzziNd0295U0z/q9qcNtVOf15N5rYsy+X9La7v+vuf5P0G0mrG+ij9dz9oKQzVyxeLWkwuz+ozn8sPZfTWyu4+5i7v5Ld/0jSpWnGG912ib56oomwXy/pLxMen1C75nt3SfvM7LCZ9TfdzCTmTZhm6z1J85psZhKF03j30hXTjLdm23Uz/XlZfEF3tdvc/Z8lrZK0MTtcbSXvfAZr09jpTknfUGcOwDFJP22ymWya8d2SHnX3v06sNbntJumrJ9utibCflHTDhMdfyZa1grufzG5PSXpBnY8dbTJ+aQbd7PZUw/18yt3H3f2Cu1+U9DM1uO2yacZ3S/qVuz+fLW58203WV6+2WxNhf1nSjWb2NTP7gqS1koYa6OMqZjYr++JEZjZL0nfVvqmohyStz+6vl/S7Bnu5TFum8c6bZlwNb7vGpz93957/SbpLnW/k35G0rYkecvr6uqTXsr83m+5N0nPqHNadV+e7jR9I+rKk/ZLekvQHSXNa1Nsv1Zna+3V1gjW/od5uU+cQ/XVJr2Z/dzW97RJ99WS78XNZIAi+oAOCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIP4OyeFugDp7XnMAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model prediction: 0\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANTUlEQVR4nO3db6hc9Z3H8c9nTRsxDZK7wRDSsKlRkBDcVIMoG1alNGYjEotaEsKSVdnbBxVa3AcrKlTUBZFtln1i4Bal6dJNKRpRatnWhriuT0puJKtX77bGEElCTIwhNJFANfnug3siV3PnzM3MOXPOzff9gsvMnO+cmS/HfPydPzPzc0QIwMXvL5puAMBgEHYgCcIOJEHYgSQIO5DErEG+mW1O/QM1iwhPtbyvkd32Gtt/sL3X9kP9vBaAernX6+y2L5H0R0nflnRQ0i5JGyLi3ZJ1GNmBmtUxst8gaW9E7IuIP0v6haR1fbwegBr1E/ZFkg5MenywWPYFtodtj9oe7eO9APSp9hN0ETEiaURiNx5oUj8j+yFJiyc9/nqxDEAL9RP2XZKutv0N21+VtF7Sy9W0BaBqPe/GR8Rnth+Q9BtJl0h6LiLeqawzAJXq+dJbT2/GMTtQu1o+VANg5iDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiZ6nbMbgXHfddaX17du3d6wtWbKk4m7aY/Xq1aX18fHxjrUDBw5U3U7r9RV22/slnZR0RtJnEbGyiqYAVK+Kkf3WiDhWwesAqBHH7EAS/YY9JP3W9m7bw1M9wfaw7VHbo32+F4A+9LsbvyoiDtm+QtKrtv8vIl6f/ISIGJE0Ikm2o8/3A9Cjvkb2iDhU3B6V9KKkG6poCkD1eg677Tm25567L2m1pLGqGgNQrX524xdIetH2udf5z4j4r0q6whfcdtttpfXZs2cPqJN2ueOOO0rr9913X8fa+vXrq26n9XoOe0Tsk/TXFfYCoEZcegOSIOxAEoQdSIKwA0kQdiAJvuLaArNmlf9nWLt27YA6mVl2795dWn/wwQc71ubMmVO67ieffNJTT23GyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXCdvQVuvfXW0vpNN91UWn/66aerbGfGmDdvXml92bJlHWuXXXZZ6bpcZwcwYxF2IAnCDiRB2IEkCDuQBGEHkiDsQBKOGNwkLVlnhFm+fHlp/bXXXiutf/zxx6X166+/vmPt1KlTpevOZN2226pVqzrWFi5cWLruRx991EtLrRARnmo5IzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH32Qfg0UcfLa13+w3zNWvWlNYv1mvpQ0NDpfWbb765tH727Nkq25nxuo7stp+zfdT22KRlQ7Zftf1ecVv+KwIAGjed3fifSvry0PKQpB0RcbWkHcVjAC3WNewR8bqk419avE7S1uL+Vkl3VtwXgIr1esy+ICIOF/c/lLSg0xNtD0sa7vF9AFSk7xN0ERFlX3CJiBFJI1LeL8IAbdDrpbcjthdKUnF7tLqWANSh17C/LGlTcX+TpJeqaQdAXbruxtveJukWSfNtH5T0I0lPSfql7fslfSDpu3U22XZ33313ab3b/Op79+4trY+Ojl5wTxeDRx55pLTe7Tp62ffdT5w40UtLM1rXsEfEhg6lb1XcC4Aa8XFZIAnCDiRB2IEkCDuQBGEHkuArrhW45557Suvdpgd+5plnqmxnxliyZElpfePGjaX1M2fOlNaffPLJjrVPP/20dN2LESM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBdfZpuvzyyzvWbrzxxr5ee8uWLX2tP1MND5f/Wtn8+fNL6+Pj46X1nTt3XnBPFzNGdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1Iguvs0zR79uyOtUWLFpWuu23btqrbuSgsXbq0r/XHxsa6PwmfY2QHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSS4zj5NJ0+e7Fjbs2dP6brXXnttaX1oaKi0fvz48dJ6m11xxRUda92muu7mjTfe6Gv9bLqO7Lafs33U9tikZY/ZPmR7T/FXPgE5gMZNZzf+p5LWTLH83yJiRfH362rbAlC1rmGPiNclzdz9SACS+jtB94Dtt4rd/HmdnmR72Pao7dE+3gtAn3oN+xZJSyWtkHRY0o87PTEiRiJiZUSs7PG9AFSgp7BHxJGIOBMRZyX9RNIN1bYFoGo9hd32wkkPvyOJ7xoCLdf1OrvtbZJukTTf9kFJP5J0i+0VkkLSfknfq7HHVjh9+nTH2vvvv1+67l133VVaf+WVV0rrmzdvLq3Xafny5aX1K6+8srReNgd7RPTS0ufOnj3b1/rZdA17RGyYYvGzNfQCoEZ8XBZIgrADSRB2IAnCDiRB2IEk3O/ljwt6M3twbzZA11xzTWn98ccfL63ffvvtpfWyn7Gu27Fjx0rr3f79lE27bLunns6ZO3duab3scunFLCKm3LCM7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBNfZW2DFihWl9auuumpAnZzv+eef72v9rVu3dqxt3Lixr9eeNYtfQp8K19mB5Ag7kARhB5Ig7EAShB1IgrADSRB2IAkuVLZAtymfu9XbbN++fbW9drefuR4bYzqDyRjZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJrrOjVmW/Dd/v78ZzHf3CdB3ZbS+2vdP2u7bfsf2DYvmQ7Vdtv1fczqu/XQC9ms5u/GeS/ikilkm6UdL3bS+T9JCkHRFxtaQdxWMALdU17BFxOCLeLO6flDQuaZGkdZLO/ebQVkl31tUkgP5d0DG77SWSvinp95IWRMThovShpAUd1hmWNNx7iwCqMO2z8ba/JukFST+MiD9NrsXEr1ZO+WOSETESESsjYmVfnQLoy7TCbvsrmgj6zyNie7H4iO2FRX2hpKP1tAigCtM5G29Jz0oaj4jNk0ovS9pU3N8k6aXq28NMFxG1/eHCTOeY/W8k/b2kt22f+2L1w5KekvRL2/dL+kDSd+tpEUAVuoY9It6Q1OnTD9+qth0AdeHjskAShB1IgrADSRB2IAnCDiTBV1xRq0svvbTndU+fPl1hJ2BkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkuM6OWt17770daydOnChd94knnqi6ndQY2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCa6zo1a7du3qWNu8eXPHmiTt3Lmz6nZSY2QHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSTcbZ5r24sl/UzSAkkhaSQi/t32Y5L+UdJHxVMfjohfd3ktJtUGahYRU866PJ2wL5S0MCLetD1X0m5Jd2piPvZTEfGv022CsAP16xT26czPfljS4eL+SdvjkhZV2x6Aul3QMbvtJZK+Ken3xaIHbL9l+znb8zqsM2x71PZoX50C6EvX3fjPn2h/TdJ/S/qXiNhue4GkY5o4jn9CE7v693V5DXbjgZr1fMwuSba/IulXkn4TEed9e6EY8X8VEcu7vA5hB2rWKexdd+NtW9KzksYnB704cXfOdySN9dskgPpM52z8Kkn/I+ltSWeLxQ9L2iBphSZ24/dL+l5xMq/stRjZgZr1tRtfFcIO1K/n3XgAFwfCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoOesvmYpA8mPZ5fLGujtvbW1r4keutVlb39VafCQL/Pft6b26MRsbKxBkq0tbe29iXRW68G1Ru78UAShB1IoumwjzT8/mXa2ltb+5LorVcD6a3RY3YAg9P0yA5gQAg7kEQjYbe9xvYfbO+1/VATPXRie7/tt23vaXp+umIOvaO2xyYtG7L9qu33itsp59hrqLfHbB8qtt0e22sb6m2x7Z2237X9ju0fFMsb3XYlfQ1kuw38mN32JZL+KOnbkg5K2iVpQ0S8O9BGOrC9X9LKiGj8Axi2/1bSKUk/Oze1lu2nJR2PiKeK/1HOi4h/bklvj+kCp/GuqbdO04z/gxrcdlVOf96LJkb2GyTtjYh9EfFnSb+QtK6BPlovIl6XdPxLi9dJ2lrc36qJfywD16G3VoiIwxHxZnH/pKRz04w3uu1K+hqIJsK+SNKBSY8Pql3zvYek39rebXu46WamsGDSNFsfSlrQZDNT6DqN9yB9aZrx1my7XqY/7xcn6M63KiKuk/R3kr5f7K62Ukwcg7Xp2ukWSUs1MQfgYUk/brKZYprxFyT9MCL+NLnW5Laboq+BbLcmwn5I0uJJj79eLGuFiDhU3B6V9KImDjva5Mi5GXSL26MN9/O5iDgSEWci4qykn6jBbVdMM/6CpJ9HxPZicePbbqq+BrXdmgj7LklX2/6G7a9KWi/p5Qb6OI/tOcWJE9meI2m12jcV9cuSNhX3N0l6qcFevqAt03h3mmZcDW+7xqc/j4iB/0laq4kz8u9LeqSJHjr0daWk/y3+3mm6N0nbNLFb96kmzm3cL+kvJe2Q9J6k30kaalFv/6GJqb3f0kSwFjbU2ypN7KK/JWlP8be26W1X0tdAthsflwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/wSyThk1bZlLAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model prediction: 4\n" + ] + } + ], + "source": [ + "# Predict 5 images from validation set.\n", + "n_images = 5\n", + "test_images = x_test[:n_images]\n", + "predictions = conv_net(test_images)\n", + "\n", + "# Display image and model prediction.\n", + "for i in range(n_images):\n", + " plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')\n", + " plt.show()\n", + " print(\"Model prediction: %i\" % np.argmax(predictions.numpy()[i]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + }, + "colab": { + "name": "convolutional_network.ipynb", + "provenance": [] } - ], - "source": [ - "# Predict 5 images from validation set.\n", - "n_images = 5\n", - "test_images = x_test[:n_images]\n", - "predictions = conv_net(test_images)\n", - "\n", - "# Display image and model prediction.\n", - "for i in range(n_images):\n", - " plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')\n", - " plt.show()\n", - " print(\"Model prediction: %i\" % np.argmax(predictions.numpy()[i]))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 2", - "language": "python", - "name": "python2" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.15" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file