From 9927788e35df754476096f28af806643ae04e4d7 Mon Sep 17 00:00:00 2001 From: Christopher Hu <40853270+christopher192@users.noreply.github.com> Date: Wed, 15 Nov 2023 23:42:13 +0800 Subject: [PATCH] update --- 04-food-classification-cnn-pytorch.ipynb | 343 +++++++++ ...4-food-classification-cnn-tensorflow.ipynb | 0 ... => 04-mnist-computer-vision-pytorch.ipynb | 676 +++++++++++++++++- ...d-classification-transfer-learning-1.ipynb | 2 +- ...d-classification-transfer-learning-2.ipynb | 2 +- ...d-classification-transfer-learning-3.ipynb | 2 +- ...ral-language-processing-introduction.ipynb | 2 +- 7 files changed, 1006 insertions(+), 21 deletions(-) create mode 100644 04-food-classification-cnn-pytorch.ipynb rename food-classification-cnn-tensorflow.ipynb => 04-food-classification-cnn-tensorflow.ipynb (100%) rename mnist-computer-vision-pytorch.ipynb => 04-mnist-computer-vision-pytorch.ipynb (82%) rename food-classification-transfer-learning-1.ipynb => 05-food-classification-transfer-learning-1.ipynb (99%) rename food-classification-transfer-learning-2.ipynb => 05-food-classification-transfer-learning-2.ipynb (99%) rename food-classification-transfer-learning-3.ipynb => 05-food-classification-transfer-learning-3.ipynb (99%) rename natural-language-processing-introduction.ipynb => 06-natural-language-processing-introduction.ipynb (99%) diff --git a/04-food-classification-cnn-pytorch.ipynb b/04-food-classification-cnn-pytorch.ipynb new file mode 100644 index 0000000..a2c7f5e --- /dev/null +++ b/04-food-classification-cnn-pytorch.ipynb @@ -0,0 +1,343 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "464ba769", + "metadata": {}, + "source": [ + "## Convolutional Neural Networks and Computer Vision with PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b831898b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pathlib\n", + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torchvision import datasets, transforms" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0cda9a43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.1.0+cpu'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "f8298624", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'cpu'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device" + ] + }, + { + "cell_type": "markdown", + "id": "32d64866", + "metadata": {}, + "source": [ + "### 1. 
Data preparation" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "96999cfb", + "metadata": {}, + "outputs": [], + "source": [ + "train_dir, test_dir = \"data/pizza_steak_sushi/train\", \"data/pizza_steak_sushi/test\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "d912a4a7", + "metadata": {}, + "outputs": [], + "source": [ + "data_transform = transforms.Compose([\n", + " transforms.Resize(size = (64, 64)),\n", + " transforms.RandomHorizontalFlip(p = 0.5),\n", + " transforms.ToTensor()\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "66ce4314", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data:\n", + "Dataset ImageFolder\n", + " Number of datapoints: 225\n", + " Root location: data/pizza_steak_sushi/train\n", + " StandardTransform\n", + "Transform: Compose(\n", + " Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)\n", + " RandomHorizontalFlip(p=0.5)\n", + " ToTensor()\n", + " )\n", + "Test data:\n", + "Dataset ImageFolder\n", + " Number of datapoints: 75\n", + " Root location: data/pizza_steak_sushi/test\n", + " StandardTransform\n", + "Transform: Compose(\n", + " Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)\n", + " RandomHorizontalFlip(p=0.5)\n", + " ToTensor()\n", + " )\n" + ] + } + ], + "source": [ + "train_data = datasets.ImageFolder(root = train_dir,\n", + " transform = data_transform,\n", + " target_transform = None)\n", + "\n", + "test_data = datasets.ImageFolder(root = test_dir, \n", + " transform = data_transform)\n", + "\n", + "print(f\"Train data:\\n{train_data}\\nTest data:\\n{test_data}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "a1224c74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['pizza', 'steak', 'sushi']" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_name = train_data.classes\n", + "class_name" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1a3df960", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pizza': 0, 'steak': 1, 'sushi': 2}" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_dict = train_data.class_to_idx\n", + "class_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2d203c6c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(225, 75)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_data), len(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "71c80d09", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image tensor:\n", + "tensor([[[0.1137, 0.1020, 0.0980, ..., 0.1255, 0.1216, 0.1176],\n", + " [0.1059, 0.0980, 0.0980, ..., 0.1294, 0.1294, 0.1294],\n", + " [0.1020, 0.0980, 0.0941, ..., 0.1333, 0.1333, 0.1333],\n", + " ...,\n", + " [0.1098, 0.1098, 0.1255, ..., 0.1686, 0.1647, 0.1686],\n", + " [0.0902, 0.0941, 0.1098, ..., 0.1686, 0.1647, 0.1686],\n", + " [0.0863, 0.0863, 0.0980, ..., 0.1686, 0.1647, 0.1647]],\n", + "\n", + " [[0.0745, 0.0706, 0.0745, ..., 0.0588, 0.0588, 0.0588],\n", + " [0.0745, 0.0706, 0.0745, ..., 0.0627, 0.0627, 0.0627],\n", + " [0.0706, 0.0745, 0.0745, ..., 0.0706, 0.0706, 0.0706],\n", + " 
...,\n", + " [0.1255, 0.1333, 0.1373, ..., 0.2510, 0.2392, 0.2392],\n", + " [0.1098, 0.1176, 0.1255, ..., 0.2510, 0.2392, 0.2314],\n", + " [0.1020, 0.1059, 0.1137, ..., 0.2431, 0.2353, 0.2275]],\n", + "\n", + " [[0.0941, 0.0902, 0.0902, ..., 0.0157, 0.0196, 0.0196],\n", + " [0.0902, 0.0863, 0.0902, ..., 0.0196, 0.0157, 0.0196],\n", + " [0.0902, 0.0902, 0.0902, ..., 0.0157, 0.0157, 0.0196],\n", + " ...,\n", + " [0.1294, 0.1333, 0.1490, ..., 0.1961, 0.1882, 0.1843],\n", + " [0.1098, 0.1137, 0.1255, ..., 0.1922, 0.1843, 0.1804],\n", + " [0.1059, 0.0980, 0.1059, ..., 0.1882, 0.1804, 0.1765]]])\n", + "Image shape: torch.Size([3, 64, 64])\n", + "Image datatype: torch.float32\n", + "Image label: 0\n", + "Label datatype: \n" + ] + } + ], + "source": [ + "image, label = train_data[0][0], train_data[0][1]\n", + "\n", + "print(f\"Image tensor:\\n{image}\")\n", + "print(f\"Image shape: {image.shape}\")\n", + "print(f\"Image datatype: {image.dtype}\")\n", + "print(f\"Image label: {label}\")\n", + "print(f\"Label datatype: {type(label)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "fd927774", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(,\n", + " )" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataloader = DataLoader(dataset = train_data, batch_size = 1,\n", + " num_workers = 1, shuffle = True)\n", + "\n", + "test_dataloader = DataLoader(dataset = test_data, batch_size = 1, \n", + " num_workers = 1, shuffle = False)\n", + "\n", + "train_dataloader, test_dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "b8fcacf6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: torch.Size([1, 3, 64, 64]) -> [batch_size, color_channels, height, width]\n", + "Label shape: torch.Size([1])\n" + ] + } + ], + "source": [ + "image, label = next(iter(train_dataloader))\n", + "\n", + "print(f\"Image shape: {image.shape} -> [batch_size, color_channels, height, width]\")\n", + "print(f\"Label shape: {label.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65b95b07", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/food-classification-cnn-tensorflow.ipynb b/04-food-classification-cnn-tensorflow.ipynb similarity index 100% rename from food-classification-cnn-tensorflow.ipynb rename to 04-food-classification-cnn-tensorflow.ipynb diff --git a/mnist-computer-vision-pytorch.ipynb b/04-mnist-computer-vision-pytorch.ipynb similarity index 82% rename from mnist-computer-vision-pytorch.ipynb rename to 04-mnist-computer-vision-pytorch.ipynb index ec97865..2899fc8 100644 --- a/mnist-computer-vision-pytorch.ipynb +++ b/04-mnist-computer-vision-pytorch.ipynb @@ -216,7 +216,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataloaders: (, )\n", + "Dataloaders: (, )\n", "Length of train dataloader: 1875 batches of 32\n", "Length of test dataloader: 313 batches of 32\n" ] @@ -489,11 +489,11 @@ " \n", " def forward(self, x: torch.Tensor):\n", " x 
= self.first_block(x)\n", - " print(x.shape)\n", + "# print(x.shape)\n", " x = self.second_block(x)\n", - " print(x.shape) \n", + "# print(x.shape) \n", " x = self.classifier(x)\n", - " print(x.shape)\n", + "# print(x.shape)\n", " \n", " return x" ] @@ -664,7 +664,7 @@ "output_type": "stream", "text": [ "\r", - " 25%|██▌ | 1/4 [00:08<00:24, 8.16s/it]" + " 25%|██▌ | 1/4 [00:08<00:25, 8.49s/it]" ] }, { @@ -688,7 +688,7 @@ "output_type": "stream", "text": [ "\r", - " 50%|█████ | 2/4 [00:16<00:16, 8.43s/it]" + " 50%|█████ | 2/4 [00:16<00:16, 8.36s/it]" ] }, { @@ -712,7 +712,7 @@ "output_type": "stream", "text": [ "\r", - " 75%|███████▌ | 3/4 [00:25<00:08, 8.37s/it]" + " 75%|███████▌ | 3/4 [00:24<00:08, 8.23s/it]" ] }, { @@ -735,7 +735,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 4/4 [00:32<00:00, 8.23s/it]" + "100%|██████████| 4/4 [00:33<00:00, 8.31s/it]" ] }, { @@ -745,7 +745,7 @@ "\n", "Train loss: 0.44251 | Test loss: 0.46306, Test accuracy: 83.75%\n", "\n", - "Train time on cpu: 32.934 seconds\n" + "Train time on cpu: 33.231 seconds\n" ] }, { @@ -953,7 +953,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0\n", + "Epoch: 1\n", "---------\n", "Train loss: 1.05878 | Train accuracy: 62.19%\n" ] @@ -963,7 +963,7 @@ "output_type": "stream", "text": [ "\r", - " 33%|███▎ | 1/3 [00:08<00:16, 8.29s/it]" + " 33%|███▎ | 1/3 [00:08<00:16, 8.16s/it]" ] }, { @@ -972,7 +972,7 @@ "text": [ "Test loss: 0.99506 | Test accuracy: 64.13%\n", "\n", - "Epoch: 1\n", + "Epoch: 2\n", "---------\n", "Train loss: 0.91331 | Train accuracy: 66.40%\n" ] @@ -982,7 +982,7 @@ "output_type": "stream", "text": [ "\r", - " 67%|██████▋ | 2/3 [00:16<00:08, 8.11s/it]" + " 67%|██████▋ | 2/3 [00:16<00:08, 8.14s/it]" ] }, { @@ -991,7 +991,7 @@ "text": [ "Test loss: 0.90508 | Test accuracy: 66.86%\n", "\n", - "Epoch: 2\n", + "Epoch: 3\n", "---------\n", "Train loss: 0.87981 | Train accuracy: 67.25%\n" ] @@ -1000,7 +1000,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 3/3 [00:24<00:00, 8.12s/it]" + "100%|██████████| 3/3 [00:24<00:00, 8.23s/it]" ] }, { @@ -1009,7 +1009,7 @@ "text": [ "Test loss: 0.89485 | Test accuracy: 66.47%\n", "\n", - "Train time on cpu: 24.364 seconds\n" + "Train time on cpu: 24.686 seconds\n" ] }, { @@ -1101,9 +1101,651 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "f6ac8d33", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image batch shape: torch.Size([32, 3, 64, 64]) -> [batch_size, color_channels, height, width]\n", + "Single image shape: torch.Size([3, 64, 64]) -> [color_channels, height, width]\n", + "Single image pixel values:\n", + "tensor([[[ 1.9269, 1.4873, 0.9007, ..., 1.8446, -1.1845, 1.3835],\n", + " [ 1.4451, 0.8564, 2.2181, ..., 0.3399, 0.7200, 0.4114],\n", + " [ 1.9312, 1.0119, -1.4364, ..., -0.5558, 0.7043, 0.7099],\n", + " ...,\n", + " [-0.5610, -0.4830, 0.4770, ..., -0.2713, -0.9537, -0.6737],\n", + " [ 0.3076, -0.1277, 0.0366, ..., -2.0060, 0.2824, -0.8111],\n", + " [-1.5486, 0.0485, -0.7712, ..., -0.1403, 0.9416, -0.0118]],\n", + "\n", + " [[-0.5197, 1.8524, 1.8365, ..., 0.8935, -1.5114, -0.8515],\n", + " [ 2.0818, 1.0677, -1.4277, ..., 1.6612, -2.6223, -0.4319],\n", + " [-0.1010, -0.4388, -1.9775, ..., 0.2106, 0.2536, -0.7318],\n", + " ...,\n", + " [ 0.2779, 0.7342, -0.3736, ..., -0.4601, 0.1815, 0.1850],\n", + " [ 0.7205, -0.2833, 0.0937, ..., -0.1002, -2.3609, 2.2465],\n", + " [-1.3242, -0.1973, 0.2920, ..., 0.5409, 
0.6940, 1.8563]],\n", + "\n", + " [[-0.7978, 1.0261, 1.1465, ..., 1.2134, 0.9354, -0.0780],\n", + " [-1.4647, -1.9571, 0.1017, ..., -1.9986, -0.7409, 0.7011],\n", + " [-1.3938, 0.8466, -1.7191, ..., -1.1867, 0.1320, 0.3407],\n", + " ...,\n", + " [ 0.8206, -0.3745, 1.2499, ..., -0.0676, 0.0385, 0.6335],\n", + " [-0.5589, -0.3393, 0.2347, ..., 2.1181, 2.4569, 1.3083],\n", + " [-0.4092, 1.5199, 0.2401, ..., -0.2558, 0.7870, 0.9924]]])\n" + ] + } + ], + "source": [ + "torch.manual_seed(42)\n", + "\n", + "random_image = torch.randn(size = (32, 3, 64, 64))\n", + "test_image = random_image[0]\n", + "\n", + "print(f\"Image batch shape: {random_image.shape} -> [batch_size, color_channels, height, width]\")\n", + "print(f\"Single image shape: {test_image.shape} -> [color_channels, height, width]\") \n", + "print(f\"Single image pixel values:\\n{test_image}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "9f20da48", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 1.5396, 0.0516, 0.6454, ..., -0.3673, 0.8711, 0.4256],\n", + " [ 0.3662, 1.0114, -0.5997, ..., 0.8983, 0.2809, -0.2741],\n", + " [ 1.2664, -1.4054, 0.3727, ..., -0.3409, 1.2191, -0.0463],\n", + " ...,\n", + " [-0.1541, 0.5132, -0.3624, ..., -0.2360, -0.4609, -0.0035],\n", + " [ 0.2981, -0.2432, 1.5012, ..., -0.6289, -0.7283, -0.5767],\n", + " [-0.0386, -0.0781, -0.0388, ..., 0.2842, 0.4228, -0.1802]],\n", + "\n", + " [[-0.2840, -0.0319, -0.4455, ..., -0.7956, 1.5599, -1.2449],\n", + " [ 0.2753, -0.1262, -0.6541, ..., -0.2211, 0.1999, -0.8856],\n", + " [-0.5404, -1.5489, 0.0249, ..., -0.5932, -1.0913, -0.3849],\n", + " ...,\n", + " [ 0.3870, -0.4064, -0.8236, ..., 0.1734, -0.4330, -0.4951],\n", + " [-0.1984, -0.6386, 1.0263, ..., -0.9401, -0.0585, -0.7833],\n", + " [-0.6306, -0.2052, -0.3694, ..., -1.3248, 0.2456, -0.7134]],\n", + "\n", + " [[ 0.4414, 0.5100, 0.4846, ..., -0.8484, 0.2638, 1.1258],\n", + " [ 0.8117, 0.3191, -0.0157, ..., 1.2686, 0.2319, 0.5003],\n", + " [ 0.3212, 0.0485, -0.2581, ..., 0.2258, 0.2587, -0.8804],\n", + " ...,\n", + " [-0.1144, -0.1869, 0.0160, ..., -0.8346, 0.0974, 0.8421],\n", + " [ 0.2941, 0.4417, 0.5866, ..., -0.1224, 0.4814, -0.4799],\n", + " [ 0.6059, -0.0415, -0.2028, ..., 0.1170, 0.2521, -0.4372]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.2560, -0.0477, 0.6380, ..., 0.6436, 0.7553, -0.7055],\n", + " [ 1.5595, -0.2209, -0.9486, ..., -0.4876, 0.7754, 0.0750],\n", + " [-0.0797, 0.2471, 1.1300, ..., 0.1505, 0.2354, 0.9576],\n", + " ...,\n", + " [ 1.1065, 0.6839, 1.2183, ..., 0.3015, -0.1910, -0.1902],\n", + " [-0.3486, -0.7173, -0.3582, ..., 0.4917, 0.7219, 0.1513],\n", + " [ 0.0119, 0.1017, 0.7839, ..., -0.3752, -0.8127, -0.1257]],\n", + "\n", + " [[ 0.3841, 1.1322, 0.1620, ..., 0.7010, 0.0109, 0.6058],\n", + " [ 0.1664, 0.1873, 1.5924, ..., 0.3733, 0.9096, -0.5399],\n", + " [ 0.4094, -0.0861, -0.7935, ..., -0.1285, -0.9932, -0.3013],\n", + " ...,\n", + " [ 0.2688, -0.5630, -1.1902, ..., 0.4493, 0.5404, -0.0103],\n", + " [ 0.0535, 0.4411, 0.5313, ..., 0.0148, -1.0056, 0.3759],\n", + " [ 0.3031, -0.1590, -0.1316, ..., -0.5384, -0.4271, -0.4876]],\n", + "\n", + " [[-1.1865, -0.7280, -1.2331, ..., -0.9013, -0.0542, -1.5949],\n", + " [-0.6345, -0.5920, 0.5326, ..., -1.0395, -0.7963, -0.0647],\n", + " [-0.1132, 0.5166, 0.2569, ..., 0.5595, -1.6881, 0.9485],\n", + " ...,\n", + " [-0.0254, -0.2669, 0.1927, ..., -0.2917, 0.1088, -0.4807],\n", + " [-0.2609, -0.2328, 0.1404, ..., -0.1325, -0.8436, -0.7524],\n", + " [-1.1399, -0.1751, -0.8705, ..., 0.1589, 
0.3377, 0.3493]]],\n",
+       "       grad_fn=<ConvolutionBackward0>)"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.manual_seed(42)\n",
+    "\n",
+    "cnn_layer = nn.Conv2d(in_channels = 3, out_channels = 10,\n",
+    "                      kernel_size = 3, stride = 1, padding = 0)\n",
+    "cnn_layer(test_image)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "67d8be53",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 3, 64, 64])"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Add an extra batch dimension -> [1, 3, 64, 64]\n",
+    "test_image.unsqueeze(dim = 0).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "eabd0c21",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 10, 62, 62])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "cnn_layer(test_image.unsqueeze(dim = 0)).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "ed9df8ab",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Before - torch.Size([1, 10, 62, 62])\n",
+      "After - torch.Size([1, 10, 30, 30])\n"
+     ]
+    }
+   ],
+   "source": [
+    "torch.manual_seed(42)\n",
+    "\n",
+    "cnn_layer_2 = nn.Conv2d(in_channels = 3, out_channels = 10, kernel_size = (5, 5),\n",
+    "                        stride = 2, padding = 0)\n",
+    "\n",
+    "print(\"Before - \", cnn_layer(test_image.unsqueeze(dim = 0)).shape)\n",
+    "print(\"After - \", cnn_layer_2(test_image.unsqueeze(dim = 0)).shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "id": "0399ddcd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "OrderedDict([('weight', tensor([[[[ 0.0883,  0.0958, -0.0271,  0.1061, -0.0253],\n",
+      "          [ 0.0233, -0.0562,  0.0678,  0.1018, -0.0847],\n",
+      "          [ 0.1004,  0.0216,  0.0853,  0.0156,  0.0557],\n",
+      "          [-0.0163,  0.0890,  0.0171, -0.0539,  0.0294],\n",
+      "          [-0.0532, -0.0135, -0.0469,  0.0766, -0.0911]],\n",
+      "\n",
+      "         [[-0.0532, -0.0326, -0.0694,  0.0109, -0.1140],\n",
+      "          [ 0.1043, -0.0981,  0.0891,  0.0192, -0.0375],\n",
+      "          [ 0.0714,  0.0180,  0.0933,  0.0126, -0.0364],\n",
+      "          [ 0.0310, -0.0313,  0.0486,  0.1031,  0.0667],\n",
+      "          [-0.0505,  0.0667,  0.0207,  0.0586, -0.0704]],\n",
+      "\n",
+      "         [[-0.1143, -0.0446, -0.0886,  0.0947,  0.0333],\n",
+      "          [ 0.0478,  0.0365, -0.0020,  0.0904, -0.0820],\n",
+      "          [ 0.0073, -0.0788,  0.0356, -0.0398,  0.0354],\n",
+      "          [-0.0241,  0.0958, -0.0684, -0.0689, -0.0689],\n",
+      "          [ 0.1039,  0.0385,  0.1111, -0.0953, -0.1145]]],\n",
+      "\n",
+      "\n",
+      "        [[[-0.0903, -0.0777,  0.0468,  0.0413,  0.0959],\n",
+      "          [-0.0596, -0.0787,  0.0613, -0.0467,  0.0701],\n",
+      "          [-0.0274,  0.0661, -0.0897, -0.0583,  0.0352],\n",
+      "          [ 0.0244, -0.0294,  0.0688,  0.0785, -0.0837],\n",
+      "          [-0.0616,  0.1057, -0.0390, -0.0409, -0.1117]],\n",
+      "\n",
+      "         [[-0.0661,  0.0288, -0.0152, -0.0838,  0.0027],\n",
+      "          [-0.0789, -0.0980, -0.0636, -0.1011, -0.0735],\n",
+      "          [ 0.1154,  0.0218,  0.0356, -0.1077, -0.0758],\n",
+      "          [-0.0384,  0.0181, -0.1016, -0.0498, -0.0691],\n",
+      "          [ 0.0003, -0.0430, -0.0080, -0.0782, -0.0793]],\n",
+      "\n",
+      "         [[-0.0674, -0.0395, -0.0911,  0.0968, -0.0229],\n",
+      "          [ 0.0994,  0.0360, -0.0978,  0.0799, -0.0318],\n",
+      "          [-0.0443, -0.0958, -0.1148,  0.0330, -0.0252],\n",
+      "          [ 0.0450, -0.0948,  0.0857, -0.0848, -0.0199],\n",
+      "          [ 0.0241,  0.0596,  0.0932,  0.1052, -0.0916]]],\n",
+      "\n",
+      "\n",
+      "        [[[ 0.0291, -0.0497, -0.0127, -0.0864, 
0.1052],\n", + " [-0.0847, 0.0617, 0.0406, 0.0375, -0.0624],\n", + " [ 0.1050, 0.0254, 0.0149, -0.1018, 0.0485],\n", + " [-0.0173, -0.0529, 0.0992, 0.0257, -0.0639],\n", + " [-0.0584, -0.0055, 0.0645, -0.0295, -0.0659]],\n", + "\n", + " [[-0.0395, -0.0863, 0.0412, 0.0894, -0.1087],\n", + " [ 0.0268, 0.0597, 0.0209, -0.0411, 0.0603],\n", + " [ 0.0607, 0.0432, -0.0203, -0.0306, 0.0124],\n", + " [-0.0204, -0.0344, 0.0738, 0.0992, -0.0114],\n", + " [-0.0259, 0.0017, -0.0069, 0.0278, 0.0324]],\n", + "\n", + " [[-0.1049, -0.0426, 0.0972, 0.0450, -0.0057],\n", + " [-0.0696, -0.0706, -0.1034, -0.0376, 0.0390],\n", + " [ 0.0736, 0.0533, -0.1021, -0.0694, -0.0182],\n", + " [ 0.1117, 0.0167, -0.0299, 0.0478, -0.0440],\n", + " [-0.0747, 0.0843, -0.0525, -0.0231, -0.1149]]],\n", + "\n", + "\n", + " [[[ 0.0773, 0.0875, 0.0421, -0.0805, -0.1140],\n", + " [-0.0938, 0.0861, 0.0554, 0.0972, 0.0605],\n", + " [ 0.0292, -0.0011, -0.0878, -0.0989, -0.1080],\n", + " [ 0.0473, -0.0567, -0.0232, -0.0665, -0.0210],\n", + " [-0.0813, -0.0754, 0.0383, -0.0343, 0.0713]],\n", + "\n", + " [[-0.0370, -0.0847, -0.0204, -0.0560, -0.0353],\n", + " [-0.1099, 0.0646, -0.0804, 0.0580, 0.0524],\n", + " [ 0.0825, -0.0886, 0.0830, -0.0546, 0.0428],\n", + " [ 0.1084, -0.0163, -0.0009, -0.0266, -0.0964],\n", + " [ 0.0554, -0.1146, 0.0717, 0.0864, 0.1092]],\n", + "\n", + " [[-0.0272, -0.0949, 0.0260, 0.0638, -0.1149],\n", + " [-0.0262, -0.0692, -0.0101, -0.0568, -0.0472],\n", + " [-0.0367, -0.1097, 0.0947, 0.0968, -0.0181],\n", + " [-0.0131, -0.0471, -0.1043, -0.1124, 0.0429],\n", + " [-0.0634, -0.0742, -0.0090, -0.0385, -0.0374]]],\n", + "\n", + "\n", + " [[[ 0.0037, -0.0245, -0.0398, -0.0553, -0.0940],\n", + " [ 0.0968, -0.0462, 0.0306, -0.0401, 0.0094],\n", + " [ 0.1077, 0.0532, -0.1001, 0.0458, 0.1096],\n", + " [ 0.0304, 0.0774, 0.1138, -0.0177, 0.0240],\n", + " [-0.0803, -0.0238, 0.0855, 0.0592, -0.0731]],\n", + "\n", + " [[-0.0926, -0.0789, -0.1140, -0.0891, -0.0286],\n", + " [ 0.0779, 0.0193, -0.0878, -0.0926, 0.0574],\n", + " [-0.0859, -0.0142, 0.0554, -0.0534, -0.0126],\n", + " [-0.0101, -0.0273, -0.0585, -0.1029, -0.0933],\n", + " [-0.0618, 0.1115, -0.0558, -0.0775, 0.0280]],\n", + "\n", + " [[ 0.0318, 0.0633, 0.0878, 0.0643, -0.1145],\n", + " [ 0.0102, 0.0699, -0.0107, -0.0680, 0.1101],\n", + " [-0.0432, -0.0657, -0.1041, 0.0052, 0.0512],\n", + " [ 0.0256, 0.0228, -0.0876, -0.1078, 0.0020],\n", + " [ 0.1053, 0.0666, -0.0672, -0.0150, -0.0851]]],\n", + "\n", + "\n", + " [[[-0.0557, 0.0209, 0.0629, 0.0957, -0.1060],\n", + " [ 0.0772, -0.0814, 0.0432, 0.0977, 0.0016],\n", + " [ 0.1051, -0.0984, -0.0441, 0.0673, -0.0252],\n", + " [-0.0236, -0.0481, 0.0796, 0.0566, 0.0370],\n", + " [-0.0649, -0.0937, 0.0125, 0.0342, -0.0533]],\n", + "\n", + " [[-0.0323, 0.0780, 0.0092, 0.0052, -0.0284],\n", + " [-0.1046, -0.1086, -0.0552, -0.0587, 0.0360],\n", + " [-0.0336, -0.0452, 0.1101, 0.0402, 0.0823],\n", + " [-0.0559, -0.0472, 0.0424, -0.0769, -0.0755],\n", + " [-0.0056, -0.0422, -0.0866, 0.0685, 0.0929]],\n", + "\n", + " [[ 0.0187, -0.0201, -0.1070, -0.0421, 0.0294],\n", + " [ 0.0544, -0.0146, -0.0457, 0.0643, -0.0920],\n", + " [ 0.0730, -0.0448, 0.0018, -0.0228, 0.0140],\n", + " [-0.0349, 0.0840, -0.0030, 0.0901, 0.1110],\n", + " [-0.0563, -0.0842, 0.0926, 0.0905, -0.0882]]],\n", + "\n", + "\n", + " [[[-0.0089, -0.1139, -0.0945, 0.0223, 0.0307],\n", + " [ 0.0245, -0.0314, 0.1065, 0.0165, -0.0681],\n", + " [-0.0065, 0.0277, 0.0404, -0.0816, 0.0433],\n", + " [-0.0590, -0.0959, -0.0631, 0.1114, 0.0987],\n", + " [ 0.1034, 0.0678, 0.0872, 
-0.0155, -0.0635]],\n", + "\n", + " [[ 0.0577, -0.0598, -0.0779, -0.0369, 0.0242],\n", + " [ 0.0594, -0.0448, -0.0680, 0.0156, -0.0681],\n", + " [-0.0752, 0.0602, -0.0194, 0.1055, 0.1123],\n", + " [ 0.0345, 0.0397, 0.0266, 0.0018, -0.0084],\n", + " [ 0.0016, 0.0431, 0.1074, -0.0299, -0.0488]],\n", + "\n", + " [[-0.0280, -0.0558, 0.0196, 0.0862, 0.0903],\n", + " [ 0.0530, -0.0850, -0.0620, -0.0254, -0.0213],\n", + " [ 0.0095, -0.1060, 0.0359, -0.0881, -0.0731],\n", + " [-0.0960, 0.1006, -0.1093, 0.0871, -0.0039],\n", + " [-0.0134, 0.0722, -0.0107, 0.0724, 0.0835]]],\n", + "\n", + "\n", + " [[[-0.1003, 0.0444, 0.0218, 0.0248, 0.0169],\n", + " [ 0.0316, -0.0555, -0.0148, 0.1097, 0.0776],\n", + " [-0.0043, -0.1086, 0.0051, -0.0786, 0.0939],\n", + " [-0.0701, -0.0083, -0.0256, 0.0205, 0.1087],\n", + " [ 0.0110, 0.0669, 0.0896, 0.0932, -0.0399]],\n", + "\n", + " [[-0.0258, 0.0556, -0.0315, 0.0541, -0.0252],\n", + " [-0.0783, 0.0470, 0.0177, 0.0515, 0.1147],\n", + " [ 0.0788, 0.1095, 0.0062, -0.0993, -0.0810],\n", + " [-0.0717, -0.1018, -0.0579, -0.1063, -0.1065],\n", + " [-0.0690, -0.1138, -0.0709, 0.0440, 0.0963]],\n", + "\n", + " [[-0.0343, -0.0336, 0.0617, -0.0570, -0.0546],\n", + " [ 0.0711, -0.1006, 0.0141, 0.1020, 0.0198],\n", + " [ 0.0314, -0.0672, -0.0016, 0.0063, 0.0283],\n", + " [ 0.0449, 0.1003, -0.0881, 0.0035, -0.0577],\n", + " [-0.0913, -0.0092, -0.1016, 0.0806, 0.0134]]],\n", + "\n", + "\n", + " [[[-0.0622, 0.0603, -0.1093, -0.0447, -0.0225],\n", + " [-0.0981, -0.0734, -0.0188, 0.0876, 0.1115],\n", + " [ 0.0735, -0.0689, -0.0755, 0.1008, 0.0408],\n", + " [ 0.0031, 0.0156, -0.0928, -0.0386, 0.1112],\n", + " [-0.0285, -0.0058, -0.0959, -0.0646, -0.0024]],\n", + "\n", + " [[-0.0717, -0.0143, 0.0470, -0.1130, 0.0343],\n", + " [-0.0763, -0.0564, 0.0443, 0.0918, -0.0316],\n", + " [-0.0474, -0.1044, -0.0595, -0.1011, -0.0264],\n", + " [ 0.0236, -0.1082, 0.1008, 0.0724, -0.1130],\n", + " [-0.0552, 0.0377, -0.0237, -0.0126, -0.0521]],\n", + "\n", + " [[ 0.0927, -0.0645, 0.0958, 0.0075, 0.0232],\n", + " [ 0.0901, -0.0190, -0.0657, -0.0187, 0.0937],\n", + " [-0.0857, 0.0262, -0.1135, 0.0605, 0.0427],\n", + " [ 0.0049, 0.0496, 0.0001, 0.0639, -0.0914],\n", + " [-0.0170, 0.0512, 0.1150, 0.0588, -0.0840]]],\n", + "\n", + "\n", + " [[[ 0.0888, -0.0257, -0.0247, -0.1050, -0.0182],\n", + " [ 0.0817, 0.0161, -0.0673, 0.0355, -0.0370],\n", + " [ 0.1054, -0.1002, -0.0365, -0.1115, -0.0455],\n", + " [ 0.0364, 0.1112, 0.0194, 0.1132, 0.0226],\n", + " [ 0.0667, 0.0926, 0.0965, -0.0646, 0.1062]],\n", + "\n", + " [[ 0.0699, -0.0540, -0.0551, -0.0969, 0.0290],\n", + " [-0.0936, 0.0488, 0.0365, -0.1003, 0.0315],\n", + " [-0.0094, 0.0527, 0.0663, -0.1148, 0.1059],\n", + " [ 0.0968, 0.0459, -0.1055, -0.0412, -0.0335],\n", + " [-0.0297, 0.0651, 0.0420, 0.0915, -0.0432]],\n", + "\n", + " [[ 0.0389, 0.0411, -0.0961, -0.1120, -0.0599],\n", + " [ 0.0790, -0.1087, -0.1005, 0.0647, 0.0623],\n", + " [ 0.0950, -0.0872, -0.0845, 0.0592, 0.1004],\n", + " [ 0.0691, 0.0181, 0.0381, 0.1096, -0.0745],\n", + " [-0.0524, 0.0808, -0.0790, -0.0637, 0.0843]]]])), ('bias', tensor([ 0.0364, 0.0373, -0.0489, -0.0016, 0.1057, -0.0693, 0.0009, 0.0549,\n", + " -0.0797, 0.1121]))])\n" + ] + } + ], + "source": [ + "# Check cnn_layer_2 internal parameter\n", + "print(cnn_layer_2.state_dict())" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "37fac5fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cnn_layer_2 weight shape: \n", + "torch.Size([10, 3, 5, 5]) -> 
[out_channels=10, in_channels=3, kernel_size=5, kernel_size=5]\n", + "\n", + "cnn_layer_2 bias shape: \n", + "torch.Size([10]) -> [out_channels=10]\n" + ] + } + ], + "source": [ + "# Shape of weight and bias within cnn_layer_2\n", + "print(f\"cnn_layer_2 weight shape: \\n{cnn_layer_2.weight.shape} -> [out_channels=10, in_channels=3, kernel_size=5, kernel_size=5]\")\n", + "print(f\"\\ncnn_layer_2 bias shape: \\n{cnn_layer_2.bias.shape} -> [out_channels=10]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "de8aa699", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test image original shape: torch.Size([3, 64, 64])\n", + "Test image with unsqueezed dimension: torch.Size([1, 3, 64, 64])\n", + "Shape after going through cnn_layer(): torch.Size([1, 10, 62, 62])\n", + "Shape after going through cnn_layer() and max_pool_layer(): torch.Size([1, 10, 31, 31])\n" + ] + } + ], + "source": [ + "print(f\"Test image original shape: {test_image.shape}\")\n", + "print(f\"Test image with unsqueezed dimension: {test_image.unsqueeze(dim = 0).shape}\")\n", + "\n", + "max_pool_layer = nn.MaxPool2d(kernel_size = 2)\n", + "\n", + "test_image_through_cnn_layer = cnn_layer(test_image.unsqueeze(dim = 0))\n", + "print(f\"Shape after going through cnn_layer(): {test_image_through_cnn_layer.shape}\")\n", + "\n", + "test_image_through_cnn_layer_and_max_pool = max_pool_layer(test_image_through_cnn_layer)\n", + "print(f\"Shape after going through cnn_layer() and max_pool_layer(): {test_image_through_cnn_layer_and_max_pool.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "98f234f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random tensor:\n", + "tensor([[[[0.3367, 0.1288],\n", + " [0.2345, 0.2303]]]])\n", + "Random tensor shape: torch.Size([1, 1, 2, 2])\n", + "\n", + "Max pool tensor:\n", + "tensor([[[[0.3367]]]]) <- this is the maximum value from random_tensor\n", + "Max pool tensor shape: torch.Size([1, 1, 1, 1])\n" + ] + } + ], + "source": [ + "torch.manual_seed(42)\n", + "\n", + "random_tensor = torch.randn(size = (1, 1, 2, 2))\n", + "\n", + "print(f\"Random tensor:\\n{random_tensor}\")\n", + "print(f\"Random tensor shape: {random_tensor.shape}\")\n", + "\n", + "max_pool_layer = nn.MaxPool2d(kernel_size = 2)\n", + "max_pool_tensor = max_pool_layer(random_tensor)\n", + "\n", + "print(f\"\\nMax pool tensor:\\n{max_pool_tensor} <- this is the maximum value from random_tensor\")\n", + "print(f\"Max pool tensor shape: {max_pool_tensor.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "474337e0", + "metadata": {}, + "outputs": [], + "source": [ + "loss_function = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.SGD(params = third_model.parameters(), lr = 0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "36aaf9d6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + " 0%| | 0/3 [00:00
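
The layer shapes printed in the notebook cells above all follow the usual output-size rule for nn.Conv2d and nn.MaxPool2d: out = floor((in + 2 * padding - kernel_size) / stride) + 1 per spatial dimension. A minimal sketch of that arithmetic, assuming the notebook's 64x64 inputs; the helper conv2d_out_size is illustrative and not defined anywhere in the patch:

    def conv2d_out_size(in_size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
        # Illustrative helper (not part of the patch):
        # out = floor((in + 2 * padding - kernel_size) / stride) + 1
        return (in_size + 2 * padding - kernel_size) // stride + 1

    # Reproduces the shapes seen above:
    assert conv2d_out_size(64, kernel_size=3, stride=1) == 62  # cnn_layer: [1, 10, 62, 62]
    assert conv2d_out_size(64, kernel_size=5, stride=2) == 30  # cnn_layer_2: [1, 10, 30, 30]
    assert conv2d_out_size(62, kernel_size=2, stride=2) == 31  # MaxPool2d(2) after cnn_layer: [1, 10, 31, 31]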