diff --git a/U_Net_Tensorflow.ipynb b/U_Net_Tensorflow.ipynb new file mode 100644 index 0000000..a58d390 --- /dev/null +++ b/U_Net_Tensorflow.ipynb @@ -0,0 +1,488 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 175, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import random\n", + "\n", + "import numpy as np\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "metadata": {}, + "outputs": [], + "source": [ + "warnings.filterwarnings('ignore', category=UserWarning, module='skimage')\n", + "seed = 2019\n", + "random.seed = seed\n", + "np.random.seed = seed\n", + "tf.seed = seed" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Sequence are a safer way to do multiprocessing. This structure guarantees that the \n", + "network will only train once on each sample per epoch which is not the case with generators\"\"\"\n", + "\n", + "class ProcessData(keras.utils.Sequence):\n", + " \"\"\"\n", + " Args:\n", + " ids: It can be train or test ids\n", + " **path: It can be train or test path**\n", + " batch_size: No. of images in the batch\n", + " img_size: Size of image we need to reshape into\n", + " \"\"\"\n", + " \n", + " def __init__(self, ids, path, batch_size=8, img_size=128 ): #constructor\n", + " self.ids = ids\n", + " self.path = path\n", + " self.batch_size = batch_size\n", + " self.img_size = img_size\n", + " \n", + " #To load corresponding image and mask using id \n", + " def __load__(self, id_name):\n", + " \n", + " img_path = os.path.join(self.path, id_name, 'images' , id_name + \".png\")#1 image path\n", + " mask_path = os.path.join(self.path, id_name, 'masks/') #mask folder path\n", + " all_masks = os.listdir(mask_path) #list of masks in folder\n", + " \n", + " \n", + " img = cv2.imread(img_path, 1) #reads in rgb format\n", + " image = cv2.resize(img, (self.img_size, self.img_size)) #resize image\n", + " \n", + " mask = np.zeros((self.img_size, self.img_size, 1))\n", + "\n", + " for mask_file in all_masks:\n", + " \n", + " mask_img_path = mask_path + mask_file #1 mask path\n", + " mask_img = cv2.imread(mask_img_path, -1) #reads in gray-scale format\n", + " mask1 = resize(mask_img, (self.img_size, self.img_size)) #(128,128) but we need 2 add 1 more size\n", + " mask1_exp = np.expand_dims(mask1, axis=-1) #(128,128,1)\n", + "\n", + " mask = np.maximum(mask, mask1_exp) #mask is updated\n", + " \n", + " img = image/255\n", + " mask = mask/255\n", + " \n", + " return img, mask\n", + " \n", + " def __getitem__(self, index):\n", + "\n", + " #\n", + " if(index+1)*self.batch_size > len(self.ids):\n", + " self.batch_size = len(self.ids) - index*self.batch_size\n", + " \n", + " #Batch of files (based on batch_size)\n", + " files_batch = self.ids[index*self.batch_size : (index+1)*self.batch_size]\n", + " \n", + " image = []\n", + " mask = []\n", + " \n", + " #Getting images and masks from batch of files\n", + " for id_name in files_batch:\n", + " _img, _mask = self.__load__(id_name)\n", + " image.append(_img)\n", + " mask.append(_mask)\n", + " \n", + " image = np.array(image)\n", + " mask = np.array(mask)\n", + " \n", + " return image, mask\n", + " \n", + " def on_epoch_end(self):\n", + " pass\n", + " \n", + " def __len__(self):\n", + " return int(np.ceil(len(self.ids)/float(self.batch_size)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "metadata": {}, + "outputs": [], + "source": [ + "img_size = 128\n", + "train_path = '/home/sumanthmeenan/Desktop/projects/U-Net/data/train/'\n", + "batch_size = 8\n", + "epochs = 5\n", + "\n", + "train_data_ids = os.listdir(train_path) \n", + " \n", + "validation_data_size = 10\n", + "\n", + "valid_ids = train_data_ids[:validation_data_size]\n", + "train_ids = train_data_ids[validation_data_size:]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Getting X and Y" + ] + }, + { + "cell_type": "code", + "execution_count": 179, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(8, 128, 128, 3) (8, 128, 128, 1)\n" + ] + } + ], + "source": [ + "gen = ProcessData(train_ids, train_path, batch_size=batch_size, img_size=img_size)\n", + "x, y = gen.__getitem__(0)\n", + "print(x.shape, y.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Display Image and Mask" + ] + }, + { + "cell_type": "code", + "execution_count": 180, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 180, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAACuCAYAAAA4eMYdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJztnXuMXNd937+/eS7J2SW5pCmTlOSHJAiIBcMWlcRU04pGQEKJGNgBrCBm0JhOYDpATdRSi8KOAYVxEURA0y0QqkijoDadQnQSo3BjU65D1ugyf0gsLLWRrTzkyIpkkqIocfnYHe7uPE//2DlXZ8/+zuPO3Nm5u/p9gMXO3Mc5v3N37/f8zvecuUNKKQiCIAjrl8KoAxAEQRCGiwi9IAjCOkeEXhAEYZ0jQi8IgrDOEaEXBEFY54jQC4IgrHOGJvRE9CARvURELxPRF4ZVjyDkHbkXhFFDw1hHT0RFAD8CsB/ABQDfB/BJpdTfZV6ZIOQYuReEPDCsjP5nALyslHpFKdUE8GcAPjakugQhz8i9IIyc0pDK3Q3gvPH+AoCfdR1MRFHDCiJKXo/qE71mDCZZxOMqe1j1rRZEhFKphGq1mmxrt9tot9vodDoA/O3R18U45opS6l3DijdjUt0LQPz9IKxdNm/ejDvuuAPA0v/33NwcXn75ZXS73dRlKaWCwjEsoQ9CREcAHEkCKZWglEpuak7UmRs+ec+JpH1crJDa9ZkUi8VlMSql0O12oZRKzouNz4aIkh/XfvMfwVWfry5X2+zrrdtm1xmDPle/LpVKuOOOO/DBD34QwNLf+saNGzh//jxee+013Lx5M7mGXJx2bO12+7VUAa0B7PtBWN98/vOfx7Fjx5L3SimcOnUKhw4dQr1ez7y+YVk3FwHcZry/tbctQSn1pFLqPqXUffpmLhQK5n5WMGMyWfMYruPgjjNxjRyKxSJKpRI2bdqEjRs3JqLvq8O1PSS0dhvMOuz6fJ2DiRZuLcScsHIxuTrYUD32tm63i263i1KphImJCezcuROTk5PLOk87zjT15pTgvQAsvx9WLbIMqFQqGB8fR61WG3UoaxoiwsGDB/Hoo48OpfxhCf33AdxFRO8jogqAXwXwLW8ghUK0GMdmsLECzHUo5jH6Z8OGDbjzzjvx4Q9/GB/4wAewY8cOlEqlZeLpq8eugxN3n6iZAujK/kNtttvmG72E6vJhdyKzs7NYWFjAwsICgCWB2LBhA7Zs2YJiseiM3U4A1iCp74W1Qq1Wwze+8Q2cP38er7zyCh566KFRh7RmePbZZ9FoNNBoNJJtRISjR48OpdMcinWjlGoT0ecA/BWAIoCvKKX+1nN8KnuAyxbNfZz42XaEeZ7LttBoIbr11lvx0z/909i0aRMajQY2b96MF154ATMzM+h0Os5MtNvtRom9r62uzsSsK9QOX/3may2svk4gjaeulEK9XsfMzAwAYOvWrYnQVyqVYOYe05a8kvZeWEs8+uij+KVf+qXk73fy5El84hOfwJkzZ0YcWf555plnMDc3BwDL5q7M+yFLhubRK6W+A+A7KY5n33PixR3vs0Ls7T5BscvWHjMAvPvd78bOnTtRqVSwsLCAQqGA119/HdevX0er1WI7EHveITR6cOETeLuu2DJ9WX+hUEChUECxWEw6Yt2ZhbA7vG63i0ajgYsXlxyLyclJjI+Po91uo9ForGiDq51rWOxT3Qt5R2ecR48eXfY/NDExgUcffRRnz55Fs9kcVXhDo1arJf+Dg/ro9Xodx48fBwAcO3YsuY7PPPPMsiw/K9b0mFgQBEEIM7JVNy7sbNDEl6Xbvq5dnqscM4t3ZYxmVlkoFFAul5PVKBMTE4m1Y5Znv+beu9rE+eiFQsHZFm7y1lWXa2WLfVylUsHk5CRqtRqazSauX7+Oer2OTqeTWFGxk9lEhHa7nVg3L7/8Mnbs2IFGo4E33niDHSnEjtyE1Uf/DSqVyop9999/P6rV6rrL6B966CF89atfRbVaRaPRwBNPPIGpqamBMvupqSkAwJ49e/DAAw+g2WxiampqKNcud0LP+e++FRn6HNMeKZfLyUoOvV5bzwH4xMm1T6/1vnDhAmZmZnD77bcnZc3Pzyf7Y7CtnJh9nLiVSqXkRmu322i1WivmCdLYIGYdhUIBk5OTuPXWW1Gr1VAsFlGv1/GjH/0I169fBwDvnIqrXP0PfPHiRczMzEAphYWFBe/1i7XkBDfacgDeni8ZBP335QRJTzKuNpVKJfG6s2ijSa1Ww4kTJ7B9+/Zk27Fjx7Bnzx48/PDDfQuzjvHhhx9GtVrNPG6T3Fs35j+oLfp6RYbOssvlMrZs2YIdO3Zg165d2LVrF9797ndjYmJi2eoYrnz7ve2rK6UwMzODF198ERcvXsTc3BzOnz+PK1euJAIbk+HanZZrtQ03X1AoFDA2NoaxsTHs3r0bd955J+68807cdttt2Lx5MyqVyrKJVN+kNReLplQqYfPmzcky0vHxcdxyyy3YtWtXMlnk8/ft9/pHL69sNBqYnZ1FvV5Hq9Vi22qTdtXPO51arYZarYbHHnsMr7zyCi5cuIALFy5ksjqmXq+jXq/j05/+NBYXF5PtSqlV9+d1O/XqH93Gxx57DJOTk5nUoUe49rYDBw5g3759A5ffbDYxNzc3NJEHhvSsm9RBECk94Wlnm65VJebrYrGIzZs3A1ha0aFFX9sdi4uLeOuttzA3N8faBNx7cwWKfl0ulzE+Po5du3ahVqthZmYGr7/+Om7evOm0H3wTjcx1YN/rDq1YLGL37t0AgNtvvx3VahWdTgftdhvFYhH/9E//hEuXLqHZbLLXzxZU12olvYx09+7dKJVKyc/MzAxeeOEFXL9+3Tsx6/pb2THZq3u4ttsTwkopNBqN59UaW2+eBsrgk7GPPfYYgOUTfZrZ2dlMVsdUKhV84xvfwAMPPAAAOHv2LH7t135tqILF1Q9g2eofAElidvjwYTz99NMD1TM+Po4LFy5gYmJixb5jx47hd3/3dwcqf1BUnj8Za9Nvh6MFasOGDQCwbPimf1cqFWzcuDH5Bwwt53N5/J1OB/V6Ha+++ioKhQJarRYajcaKT3W62pdmuSMXT6VSwbZt25LXurxSqYRyuYzbb78d9XqdFWKu83Rdg263i1arlXy4qVAooN1uo1qtYsOGDZidnU1lt7g6VteyU3Ol09atW5N5gtnZWczPzzvrFZao1Wr4zGc+A4C3uiYmJnD06FFMT0+j1Wr1XU+z2UxsBwBoNBqrms3v27cPBw4cAMAnSdu3b8eJEydw99134+rVq6sW1zDQK50+97nPLZsniO1UcyP0JjEThuZ+pRTa7XayrVgsJpk4ESVZb9oYtDiafnSn00k+9BNacmhn87EZvWteQts2epspxkopbNiwAbt27cLi4uIKQbQ7Gt9Io9PpYH5+Ho1GA2NjYygUCuh0Ouh0Osk1DXUWrrK5a8TdpO9619KjbN7znvegXC6j2+1iZmYGP/nJT5LrL/DU63X8yZ/8CQA+oweAH/7whwOJvKbZbI5s4nXv3r3J/eBi27Zt+OpXvzqQl66UGunkcqVSwVNPPQUAOHjwIAqFArrdLu6//358/OMfjyojVx49N8FqTyyav83Xs7Ozie8LLIlws9nE4uIi5ubmcPPmzSix5TA9Zt1p6B+zzDQesmu1jSl8Pr9bW1MAkk6sVCphx44dGB8fX2aLhKwqO5Zut4sbN27gxo0bWFxcRLPZRKfTQbPZTDqXGGJE3kZPpm/fvh3bt29PvNFisYgtW7ZgfHw8qu53OlNTU5iamsKpU6eWXedut4srV65genp6dMGtIkSEj370owP93+g176OyufXI5cCBA8l9XSgUcPDgQezcuTOqjFwJvQk3Ucll+lqAdWZx5coVXL58GdeuXcPMzAzeeustzMzMYHFxcdmTEs3zfXaJHZNeXsg9dTGU2ZvlpvmnMa2jVquFVquVdDI609ZiPDY2hsnJSZTL5RVt9bXNjnd+fh4XLlzApUuXcO3aNVy9ehUXL15MOswsJkZd16JaraJSqaBSqaDb7aJYLKLT6aBUKskzVSLRE6aHDh3CsWPH8NZbb+G73/0uvvSlL+Huu+9+R316tVar4ciRwZ4XNzU1hW9/+9tYXFxcZj2uRqavRy726EXP3cWQG+smZpiv4ewHLbyNRgOtVgvFYjHxzkOf6DRtDS6Ofnz1UOw6bu69y87QGbV+b09SVqtVdLtdjI+Po1wuo9FosGLKtc9+oFyn08Hs7CyazSY2bNiQiL/+R/e1I9R2+3hdtx69VavV5J+6VColnVmxWFz2mQUhTL1ex5e//OVkffZ6Wt/OLTrg4FbNpKVer+Phhx/Gvn378JGPfARElPjko0RbnCFyI/QxQuESY/N1moybq9MUWpe1YItcrNCFVuG4Mm8dU6fTwY0bNwAAW7ZsQblcRrVaxeLiIogo6eB0Nmx66a4OhKubiJIJWXtOIjQ68d143ByE+Tc10SMS/eEbXa4IfX+s1kqY1eSP//iPk6c9mmvch0Wz2cTp06dx+vTpoddl1+u6r2KTzNxaNxw+ITSPMS0Wn4duC2EIVweTdiTiisF3nm7TtWvXcO3aNVy/fn2ZN1+pVBLf3hRR80cvNzXr5mIzLTM9J2F+VsDX3ph/PHPkYsdTLBbRbrcxPz+P+fn5xLMvFArJpLMgAMDVq1dx/Pjxkfrnq8GTTz6ZWHE2N2/ejCpjTQm9IAiCkJ7cWDdp8FkrAILPYrHLAlYugeQsFjvrjcleQzZFzChFZ+LaJweAK1euYMOGDRgbG0s+IEa09MiHhYWFJNvX5eglWVzb7Xa6Rkqh9fFp8F07/clZANi4cWPi4bdarXVpQQj9oz3yo0ePOu2bRqOBZ599djXDyhTzufXm6iGlVPKo4xBrLqNPa5fY5+nXtm1gipy9rNO2V2LXkftiCQm83TnofXpi8tq1a3jzzTeTSVjtXbdaLVy7dm2Zr2fXw80z+OIzr1fM0ko7Zu6167roFUSXL1/G5cuXceXKFdy8eROtVgvXr19POgBBAJbsm6tXr+Lw4cPO/425uTk888wzqxxZdujlncePH1/2yImZmRm88cYbUWWsuYzeNcnHTZAC/WfhZl2uSVOuE/DFGxK5UJymUDYaDVy+fBndbhc7duxIPhl8/fp1XL16NVl+yU2Auiah7bhjJqN97bLbEVrGao7EdKZy/vz5JKufn5+XD0sJLE8//TQOHTqEEydOYHJyMhkFdrtdHD9+fM2PBB9//HEAwLlz5/DII49g7969OH78ePQDFXPzrJs0qyl8tow58RjzaAKbkMjHxtfPuZwom9vN1/phbvqbmqrVKgqFAm7evJl8yMmMQZdpXhOfQPvis22ufnGNmMy6zIfR6YeiLS4uyrNuBJbJyUkcOXJk2WMZ0jwqYC2gn9SpHzmhIp51syaFPlBW8ptbcRPK6vtdfZNmLX2oXE70bEE0OzS9WkVbK9wjmUMWik/c7fkKU3jtMlwWVKhc3zaz3b0PqYjQC0KPGKHv26MnotuI6H8T0d8R0d8S0b/ubT9GRBeJ6G96P7/Ybx0p41kmMnY275p01dvs1y5PnhtNpBV5ex7ALt/GFlBT0M1PxprLIE2xd80zhLA9ea5zs8u05zrs81xl+eYFzPbmITHhyNv9IAgmg3j0bQD/Rin1f4loHMDzRKQ/V/2flFJ/kLZAnyXjOpYTE185nMC7hDuNHx0LlyHbsbgyevM4Lhbuvd1BuNrlitUVkys+ru5Qp2Ke7xoJxMyHjJjM7wdByIq+hV4pdQnApd7rOSL6ewC7BwmmX0/bfm0Lk090QhOkofmAfucAssD233VMrro44U4r/v1MvPr+Htx7rj7f6CAPDON+EISsyGR5JRG9F8CHAfyf3qbPEdEPiOgrRLQ1soyoY2KPs39zP7Y9wJXjE5bYpzi64rezVE4Azd+udvrg2ucSZNf1dY1CuPNcIwhXeXasIcFfC2RxPwhClgws9ERUA/DfAXxeKTUL4I8A3AHgQ1jKcP6j47wjRPQcET1nbHP61iE7xhY0zpv2WR12+fY+lwCmwTUJ6spauXpdnVM/Qugqm9vva6vPtnHFZrfV16mFOrw8keX9IAhZMdCqGyIqAzgF4K+UUlPM/vcCOKWUuidQTvJVgja2zWB783q7VR7r6drnmRk55/n75gGyxmc1+er1ZeQu68bcZ3eO9vFpM29zv+vvFVMWNwrQtFqtXK66yfJ+GEqAwrpEDXnVDQH4rwD+3vynJiLzSfi/DODFNOVyGblRdlQGaJZlP9hMb3N9jZ2rrmGJvD3i8I06OOysnmuv/drVAXAdRZrjbczHD7uIHRXlPZsf1v0gCFkwyKqbfwbgXwL4IRH9TW/bbwP4JBF9CIAC8CqAz8YUFsr8zOOA5R2BSxxNH94+31eGS9hixcZ3vGuUEMrgfWW6RjgxnntsWzhCsfisNt9Iyzw3dFyOyPR+EIQsyc0HpsrlcjA7t85ht6es12tX9GvdxAq9q7wsxM2O3dWRmrGEPH/XdRkEV5m2hWa2K6/WTVaIdSOkYajWTZbEZJihicNYODFz+dqcyPhi4tBZtf6xRxeurDutheOq2z7fHMUQEfv8eu58+33aTN91rq/9vs5GEIR4ciH0aXzoQbLIWFG1fXPXca59Pl/cZ6fo32kmM31wHrvdLvMrBO1OiWtbzOjGde3seHydKteWUL2CIPDkQug1MdZNzI3u84ZD59iWAWdrcMRmqaH6XYIXOpeL33WcXb6Z1XP4OiauLp8l5YrPVT83ChIEIR25EvqYLJo7Xr+OEQGXaPiyUFedvjpC9oV9LDcZa8fnm/xMI4BmZm1+v24snKXDXVdfZxDbsdgeviAI6cmN0Jvi48qiXSszOHvERYz4xE4Kuzzsfn13267xiZwrgw+NZlzXLHQNfR1gTGfjK888Ju3IQhCEMLkReluAOaF3ZbSmz6zLSrtSJu2x9vGmGLt8cW6fazsXv291jK8Dc9Xjus6hMuz6zFh9f7NQZxwbtyAI6ciN0HPZbeyqFk5cfIISKjs0MvBZKKFYfXW7LAvXsbEWl6sOfZ6rgzKPT/P3cL23O+A0q23M/YIgpCM3Qm8TEkV9TD+4smK7XFNsY5YWhiZSzQzadazLi3YJO7cqxzW6iBXQkP2TBb4OWCZfBSFbciv0giAIQjbkTuh9Q3+bQTx1rg5fJhtbV6gO13vX+a66uUlY87VtuYSsJZ9Pn3ZVT79wIx2xagRhcHIl9JwgxXjfrrJituk6fPsHJbSM09yetkOwz+MmPznrxjdhzG2zOw8uFkEQ8kmuhN7lZYeEpJ9VI4Meb8KtgbfLDvnktnfPZeyxGblvzbk9UR27Coirg3sdc4wvdkEQsidXQg+ku+ldE7ZphSOtNWEeb098hsQtZmWOb/WLjWtJJjeBzK2gMa0dLstPa+n4Vum4rnPomsjSSkEYjFwJfVohcAmA/fyWmHrNMkNwGTonqmnnBHzn+GJzCfygHWHaWOzrYn6xS8zozNc+l20kCEKYQZ5HnzlZLevzLRUMWSC+jiZGpHxZPWebmMf1I4LmvhhBj/XduSWmrnhDk70x8yyhcyWjF4T+yVVGH8OgGR2Xvcf432lW09iZbcwcQmz53P5Qx5V2olfHE+q0YnHZROb+UPtDoyRBENzkTujNTJLLVGO9a7ss137XJGmsbx9ascNZPCEbwvb/ff66yzpyld+vneNrJzcnYIu5PdlslpnGLhPrRhDSkzuh5wTRJwixloZP0M3f5rm+SUe7rNiM3HWOy1Kx2+Eq126Hvd9+FDHR0peO6O36tyuD5jJxrq2+yVju2DS+vIi8IPRHrjz6EIPc6CGxSOOPx4wU7HNdnY3PC7ezXy6DjxF/fb4p6oVCAaVSKXlMcafTSSZPbQGOyaRDVoxv9BIauenYJaMXhkGlUkG1WsW9996LF198EY1GA/V6fdRhZcrAQk9ErwKYA9AB0FZK3UdEkwD+HMB7sfSFyL+ilLoWKAdA9s9TGYanG7JruBhc4u7L6s3XtrBzWX+M6BeLRdRqNVSrVZRKS3/+druN+fl5LCwsAABarRYbq+t6uoR9UDvIrj/vZHUvCKtHrVbDU089hQceeADVahXNZhONRgNPPPEEpqamAGBdiH5W1s1HlVIfUm9/YfMXAHxPKXUXgO/13gvCOwG5F4TcQYNmSr0s5j6l1BVj20sA9imlLhHRTgDTSqm7PWWoUqkU9H/7ibUfGyJNedx2lx3DHaPtEnPtf0zWHJrI1a/t9xs3bsTWrVtRLpcTC6fb7aLVamFubg4AMDs7i1arlWqy1K7fbi933blr6bJxzPedTud5Q0hzQxb3Qu+c/A9f1glf/OIX8Xu/93vsXNqpU6cAAIcOHcp1Vq+UCgpaFhm9AnCaiJ4noiO9bbcopS71Xr8B4Bb7JCI6QkTPEdFzvWD1dr6SPjskn8j3YwvYE7Qu2yFkvZgx+CY601wPuz3dbnfZahcAKJfLicgXi0WUSiUUi0VUKhVs2rQJmzZtwtjYmLd9PrjOzYWOz9feNebL93UvACvvB2F1qFQq7P8XEeHgwYM4ePAgjh49OoLIsiWLydifU0pdJKIdAM4Q0T+YO5VSistQlFJPAngSeDuDGZanruFEhNtuHu+bONUi5Mu2uQlVn89uH++Kx+drc8fridhWq4V2u41yuQxgSWybzWYy4QkApVIpyfR98biwOzkXsaMSXc4aEPy+7oXevhX3gzB8nn32WSwuLmJsbGzFPv3/VqlUVjuszBlY6JVSF3u/3ySibwL4GQCXiWinMVx9M1SOT+SzmpDjhNa3j7MVOGvCl3Wbx8RO3tp1xbRLr0oJHddoNDA7O4tarYZOp5Ocp5RKviS83W6zMekyfLGH/oZmOVz77GWeXNafV7K6F4TVY3p6GqdPn8aBAwdWiL3+v2s2m6MIDQCwf/9+7N27N4nhiSee6MtGGkjoiWgTgIJSaq73+gCALwP4FoBPAXi89/svB6mHs1xsiyAy3qQ81z5zf8iPd4mVL6Plsnhdpt0pcCOG2Dhd9XQ6Hdy8eRPtdhuVSgWlUik5X/8zmf/YXN2u653Wy+diJCIUi8Vlx+oOKM+s1r0gZEuz2cTDDz+Mffv24ZFHHsHevXuT/82zZ88CAI4fPz6S2Gq1Gk6ePIlt27Ytuwd///d/P3VZg2b0twD4Zu9mLQE4qZT6LhF9H8BfENFvAngNwK8MWA+AlUN6U2xDnrVLsLjtrrJd9bh8aft8X6xcG+3yQ+2z93G2i/69sLCARqOBYrGYHKsFVa+pD3U0ofhiRjuu9thlrwHrZlXvBSE7ms0mTp8+jenpaVSr1WR7o9FI9o8CIkK5XM7E0h541U0WEJEyBcdzHAD/hJ9tl/TTvpjMnxNR33uX8Nplcz672TZue6gd3Hvf+bbQau+es7Ps47n3sbi8ebu8vK66yQrx6AVgaW7gkUcewdjYWNLpcNaNilh1kxuh18sr0/jxMbZCrM3j6kRi6o+1hLjyQ961q8NyjR5C8fbbDtPLjx3puOJ3td2Heb4IvSC8TYzQ5/IRCLGC7/PC0xxj12mPCAbxpu1y0toPdkx2ndw2bkTj8/ZdnYl5nF7vz1kpsXZWKPPnynO1XxCEeHIh9FwGa+93ZYChjDcN3W7XKfZczKHMNrSPw9cRhGIxz48RVlcnyI0yXKMPDvO6+Dok7rzQNhF7QUhP7p5eKQiCIGRLLjJ6G1d2Gcqs9TH9TsaaK0xck64uC8U+ljvXjNfVVpenzcXquhau/ZyNFMqifZaTK/a019/n19sjA0EQ0pObjN7nX9tCELJ6uH22Bx8SSk7gXALssktcwmuf5yrL5VlztoivTLMMU4BjPoxkxhAzX+I6zo6Fi8dXrm532jkOQRByktFzYuoTWI6QRx8aGXDn2iLty+S5WOxOhhMpc+miK5bQdrtt5qdk08wF+ASam7hNM0kd6gBjO29BENKTi4zezl5dK0Rc+ITWJ7KhLNt1PJfpm7Hb9kUoY43JlH24RH8Q7EyeE+00lgrXEbu2p/lbC4IQJjcZfUzGHcogzWP0cTECbwuW3fH4/PK0mXFaYs6L6RxdGbWrLtecQkz9vjjsMu19euWT71xBENKRS6G3CQk7l22GPH99jPk7VjA5eykmNlc9rnNiJzPNWO26uPdpRwlpjiWiRKxD9ht3HVwjK8nkBaF/cmHd2JYHEBZ/fQzn57smIfV+rn5uHxeDS+Q5+6afrDRm/sC+Rtw2l/ATLf8icJ9lleZv4ZtcNb39mI7MNdqIiUMQhJXkIqO3ic1kXSLhErp+/P+Y+GIy/H4EysyMY9qjJ2FdE7wh68uXRbuunb2Pu6ahjs4+n+sMYmw7QRB4ciH0Si09OXEYN3Eog3TtD/nbLg/fPN5n+bhElRsZmHXbsZu/9VcD6oy92+0mjy4wl1K6RJObp0gzz8HF6IPrlLkRhoi80A+VSgXj4+PYs2cP7r33XgDAk08+CQC4evXqKENbdXLxULNCobDiO2N9k4U+sbbfx3redibssnh8vn+sd+/DLCOmTC3IpVIJpVIpeayp/WyaZrOJdru9rGMKZeyuWHxk/f/EZfPdblceaiZ42b9/Px599FHcf//9qFQqybdEaYE/fPgwnn766VGGmBlqrTy9slAoKPPLJmKG6r7sOySWsW32WQy+bYPisy5sCoUCqtUqKpXKsi/ssEW51Wqh0WisEPsQLisldJzdHn1MTH2hb8uSp1cKPh566CGcPHkSExMTzmNmZ2fxiU98AmfOnFnFyIZDjNDnYjJWEARBGB65EHqXFePL6rQ1oc+3jzezS678mFUc5vncqhFdlqsdMXCxpFmJUiwWl30loLmqxsS0dVy4rnnsCMv8zZXrmnMw229n87ET84IAAJOTkzhx4oQ3mweAiYkJ7Nu3b3WCygG5EPp+iBUsW4i5CUifANk+tvkoY66ctJidlF2Gb/JWbysWiyvEUX8VoC2g+sdup3mtXJPMvvmTUKdsl2nXy80VcJ2qCL4Q4rOf/Sy2bdsWdeyRI0dQq9WGHFE+6HvVDRHdDeDPjU3vB/AYgC0APgPgrd7231aQ3XrvAAAUtklEQVRKfSe23CxuZt+KGLMOTug4wTXLDWW0XPmu40OZuy2E3FyEXaYe6WhBN9vIbePeh+J1rf4x4wp1DHa5vuN9o4W8MKz7QRgelUol1/9TWZLJZCwRFQFcBPCzAD4NoK6U+oMU56tSaanP6SeerCZEY1fNhITXZe2YdbjqdsXFUSgUUCqVMDY2Bn399PG2fdPpdNBqtTA/P58sZQ3ZW76/hUuMfat0XB0Kl+EHRlW5nozN4n4YVmzrncnJSbz00kvYvn178NgrV67gfe9734rvYF1rrOZk7M8D+LFS6rVBCrFFIVa8s+qVbU+eyzRjVp/44gpl+zFzBmb9nU4HzWZzmV2jj7VHKK1Wi/0QVuycQEj47fK4EYOr/a6yXZ1CzsnkfhDSc/XqVRw+fBhXrlwJ/r8eP358zYt8LFkJ/a8C+Lrx/nNE9AMi+goRbU1bWMxkrI3PckmDS+B8QmRnsaaQusTU7kxComu3TR/X7XbRarWwsLCAZrOJVquVHN9ut9HpdNDpdLC4uIhOp7MirpgOlavfbic3xxEq03y9RgQ8lkzvByEdTz/9NO6++24cO3YMs7Ozy+6txcVFLC4u4tvf/jampqZGHOnqMbB1Q0QVAK8D+IBS6jIR3QLgCgAF4N8D2KmU+g3mvCMAjvTe7ikWi0G7wM5CzUnIGF88DaZocyIdk5Hb9kXIY/cR8vu1VVMsFpdN0HY6HQBvT9C6sm67fVwbfN5+yHpKcw1D5HkdfVb3w2rFu97Zv38/9u7dm/wPnjt3DgAwPT2NZrM54uiyIca6yULoPwbgXymlDjD73gvglFLqnkAZyif0Pm83izkGDp8QxQq961iX/RMq02VjmFm57W3b55if/uWE38YWeC6Dj/HhY+D+nty2nAt9JvfDcKIT1iOr5dF/EsYwlYh2Gvt+GcCLsQWZWbS9nTs2jY+cNnsMZdAubDvGZ0NxVo+5zy7ThRZs/Wybbre7LIN3/fjKdnW4dptclhQ3mWr+bV3n+eIYVqeeMZndD4KQFQNl9ES0CcBPALxfKXWjt+2/AfgQloaqrwL4rFLqUqCc5BEIXJZLRCu+kAJwZ8v9WgJcppq1d+xbieJqSygO2x/n1vr7Os+Q9cJZNdz7LPHVmdeMPsv7YcihCuuIVbFusoB6yytdNobL6rB9enO7T0AHIW15sSLt88LN42JJOyEaEno7m08TR5YdQ2/kkkuhzwoReiENq2XdCIIgCDkmN0JvZ33m9rSTsa6VLmlj4Uhbnv1pVPO17XWHbKk0Mfm8eZe37npvr3Yyt7va5SrXFXNo3iPmOEEQeHIj9JrQsD6t+PW7+iNErIDZz4V3ncuJo2ty2o7DJ9JcBxIqI9QJmNeWO8/VSdmTwKEY7IldV/yCIPjJxTdMaVyZo8ujjy1zGPiWXqapP9Qhxfjothiafnra+YRQ3bF/g7R/J64Ng5QnCMLb5CajjxEQn20wTGKW/qU93rVE0VxuabfRd33MDtLVWdrbuYw5pvzQMf2MpmJiWM2/uSCsJ3Ij9Gktm9DQnzu2X+Hxre0P2RxcmaEVLvpHr4PnRN91rl2f61yX5eK7VrF/o0FGYJK5C0L25EboTWJFxsxGfULITRKa+/oVNf3jElgXsWJmlu/ryFztC51nx2PGnWY0wZWV9niNq0OSTF4Q+ic3Qu/yZ7MSTd/5/ZbPiVBs5p8mPl+MrnkN+3zfBHBoctU8JkTscS4ridtvnycIQjpyI/RANmKSxuePERJfVhtjTXDLK+36QpZTTHtjPe7QfrvjMOMLdSpcXfb1s+PlRg5p7CpBEMLkRuhjVp24tsVknq7s2xYun3i7rB8fsRaQL+OOKceVjZtiqZTyfqWg/Z4Ta67jCom/Ld6uzD1mPkEQhPTkRuiBdGJvYvrYrv2ulSCcmKXNIkOjhBiyzlx9doydTcfEHGsBufb7ri/3t7Etq5CFJQiCm1wJvSAIgpA9ufrAFBCeHDUxM0juHNvSGUZG6MpOOdLWb45UYj34mGO5T+v64K6xyy6LWULKnWefY2byksULwmDkKqO3LZZ+rRyzvLTncOf6SDsZmmYVTqhsn40S8svTev7cdtdkaagD5o73LZMVoReEwciV0AP+D/mYx3Decz/Cx5HGc7dFyleWeUwa8Y6dHOayYV/c5jFmB2Bn0SGP3T4u9vqZx/uOy2IORBDeyeRK6GOzTWCl2NgrSXR5ruNDZbsyUm5Fik/s09TrsiliV/bECDy3ykj/5gTdZa/4Rkuu87g6+rGzBEFIR66E3keM2NmdRCjztM93lWu/962ND41IfO3gfG3uPSfSrjjS1G8fY2f7oTJCHbVrxOCLL40FJAgCz5oQetfEHicqMb586LzY7bGWhq9e13F2ma518a7z0k602iJt/uYsGfMcn2XFjRBcf09uu328WDeCkJ4ooSeirxDRm0T0orFtkojOENE/9n5v7W0nIvpDInqZiH5ARPcOGqTPo06bIafBVbZetZK2DF+srgnOkKdvnmu+NzuD0CjBFwv33qwjNDLy2VE+C8jVoY06sx/1vSAI/RCb0Z8A8KC17QsAvqeUugvA93rvAeAXANzV+zkC4I9iKggJdr+EOgNNP4Jo+9sxdfjK4jJyV6fCiZ/92xbjWBvEPCftCCnk8XNlxY7WcsIJDPleEISsiRJ6pdRfA7hqbf4YgK/1Xn8NwMeN7X+qljgHYAsR7UwTlC9LDAmPz0ZIU2eINCt+fHXYomrbIdy10PtjOyJ9nssCMTuE0DGuerifUGwxnUho32qz2veCIGTBIB79LUqpS73XbwC4pfd6N4DzxnEXetvCwTArZ2xibYd+zkuTlSulVmTbtuDH+vEuy4I7xhTjNHMXZvnmuXbGb+/T59hluzrUfkZGXLtdE7d5En2DzO8FQciSTCZj1dLdl+oOJKIjRPQcET3XKyM4XI+dQOWOdW13+c+uc+3jQn59bHtcgq3xjQ5CE5j28ZzFY9dbKBRWWECxIhszktHlhuZfYsvKC/3cC8DK+0EQsmQQob+sh6G932/2tl8EcJtx3K29bctQSj2plLpPKXXfADEIQh4Y6F4A5H4QhssgQv8tAJ/qvf4UgL80tv96b8XBRwDcMIa1TmxvmMvoXK9NdBm2BeRa5eGbD+BidNUZysi5fVx7OH+cszTMfVxmHmN52KMZ26c39/vmDbg4ffRzvWL2j5BM7wVByJqoh5oR0dcB7AOwnYguAPgdAI8D+Asi+k0ArwH4ld7h3wHwiwBeBjAP4NNpAup2u8tE32VDuGwH13ZzH/fe9qtj4awf7hiuTdwxIf/dh221+OLjvHjuWtrXJaadsbHGzmGE6l1NVvNeEISsoDxkSUSkSqWlPse+sbn3oSzcJ26FQmFomWEa8UojmJwo90tMWbHxcdfZfs2NUGJjdNHpdJ5fzxYHEY3+phTWDEqpoCjk5pOxtmh0u110u91louKbjLWzTs62IKKodemubWkni13H+DoETuT6EXdzcpuzgcxraoqxK8M3j+WuMxdjjN1jHsvF79ovCEI8uXkevU/EY1akxApnGn95UMH1ibYrq/bVEapfd2S2BRSKxbefs5Ts82NE2be6hiuXKy+LEY0gvBPJTUbvysY50mSKWQhDKCbXBGnIuoi1oGIJjXhCE80xoya7Pp+H7ipT7/ONDOxrzs3VCIIQR24yeoBfaeISCdtr5rLLLD3zmOO4dnDt4drGZc39+PKuziMmy+fOzaJuDm5EFuq8ReQFoT9yJfQhu0YfY4s8J46+MkL1prF37Dp98Zq46unXj3e1wWfX+GwdXywuPz5mG1dGaO4jFI8gCH5yY924sAXLFnF7IlH1VtaYn+yMZRAPOOStcxOZXP36N2eJuKwdl7duH+vqAAcVUBFpQcg3uRd6e5UK5+VyEFHUs3OyiC9mf8zcA9cms3y9NNTuNEL12/ZWP6KbxpbR+0MjAs6qM2PmOnFBENKTG+vG58mbYm8Lv7nfPl+LPdHyZZW+OvqN20W/6/Y5UTbbkIXFkyYObrtvpRNXhm0VuewsrqPzTewKguAn9xm9IAiCMBi5EXqXhxyaPAytHDG3F4vFZUsf09oZ/dg0g2ahrmvhwxzJuCyjWF89tISSK9deCsktP+XmECRrF4ThkBvrxsQ3ARlaKeKycbRXb1sp9tJJ30RlaLlm2tU6MfhWzdhwq5C4c7nO0SXErvp9q4y419w1t88LWTah6y8IAk9uMvoQrkyPyxb1du63PsbM5M1VOly2b5JmAjQLQh2bKzZfDDHtctXVb7u4bD3G408zmhAEgSeXGb0PLXiFQiGZnAwtbbTPtffpJZlpRCxtpmlnydzKGTMbtzuoLLNZ1/XSo55+Ri02dttiOiqurWlGNIIg8ORS6H0ervk6tEzPhc+escXOtTLEJ0SukYd5vO+8QbNmVwyDiqTvuofqcl1H17aQLScIQjy5EnrfZBwnCvZ55nbOd7YJZfjlchmVSiXZ3mq10G630W632Y7HFnPufWjOgdvvs6bsY1ydEDeC8MH5/XY7YgTfLM+M1zeXwMVubxNWn8nJSRw5cgSVSgUAcO7cOUxPT6PZbI44MiFEroQ+TSbrWnnjmlzlRMX8MJUpJoVCARs2bECtVkv+qYGldezNZhNzc3NoNptQSi37ohTO/uEEzWfDcEIdm+3bouzKkLnOJaZjtK8V17n4roMrVlfHJOSHhx56CCdOnMDk5GTyd2o0Gpiensbjjz+Os2fPjjhCwUeuhJ4j5E2HslRbRLQw23WY28rlMmq1GsbGxqC/EMXc1+120W630el0WEEN2TSu7Jvz8M34fELMCXloQjXUgYTa5ZrojhFpX0bvaqdk9KNh//79OHnyJCYmJpZtHxsbw4MPPggAePbZZyWzzzG5XHUTI2IxFo+9esMUJvtHo7P5crm8QrT0640bNyb7zR8zBnubvd0UUJ9IhzolGy6TdollSGhdI4q0Vk2ofHtfqIMSVpd9+/atEHl7/759+1YvICE1QaEnoq8Q0ZtE9KKx7T8Q0T8Q0Q+I6JtEtKW3/b1EtEBEf9P7+S9pAwpleIP6ta7MVW8vFArYuHEjqtVqsrKn2Wyi2WyCiFAqlZKlmDGiFxJmMwZfB8e9d9Xnapt+z+2PEebYGMzjYkTbNWKJvcaryWrfD6OmVqvhyJEjow5DGJCYjP4EgAetbWcA3KOU+iCAHwH4orHvx0qpD/V+fiubMAUhN5yA3A/CGiMo9EqpvwZw1dp2WinV7r09B+DWLIKJzRa5rNRXju0l6yxRfyetPWGrM3ad0Xc6HXQ6neSchYUFtNttr2Xjsm6432nanoY0ZbpGAr6fGFz2mmty1szizW15YTXvhzxARMsWJHBMT09jenp6dQIS+iILj/43APxP4/37iOj/EdFZIvrnaQqKFRCXvcNNYLqO584H3l5ZowUeAIrFIorFIjqdDur1Oubm5tBut5OOwi7HZUXY9ozLRgm1vd8O0ayDs3T6vf6xto/LRrKvi9lRhqynHJLZ/ZAHGo0GnnnmGef+2dlZTE1NyURszhlo1Q0RfQlAG8BTvU2XANyulJohoj0A/gcRfUApNcucewTACvOPmwA1RZLbZ59v7nfEnew3y9AZ+9zc3LJz9cqbbreLer2OVqvlFSnzccKuttltsWN34TrfbofdsYQmXmPh/g6+8n1tipmMNeMPlTdqhnE/jJpms4mpqSns2bMH1Wp1xf5Dhw7hzJkzI4hMSANFZnDvBXBKKXWPse0wgM8C+Hml1LzjvGkA/1Yp9VygfFUsFu1tzuN9MbssAZ/om1k4ESUZvN5nLrHU2T6XzdvxhYTVlfX7sMsOtW9QuLZwdaYVYK5jcNk89vGdTud5pdR9qSrMkNW4H7KKNStqtVpy/bWVc88998j6+RyglArefH1l9ET0IIB/B+AB85+aiN4F4KpSqkNE7wdwF4BX+qnDJSJcFuwa2puPM7AzQ+54/brT6Sxbb99ut5NjXHYNhy7DJYqucmKyb1e23o/g28LNZdS+0Yhr1JA2Jt9IYZD2DZvVuB9GTb1eX7FNRH7tEBR6Ivo6gH0AthPRBQC/g6VVBVUAZ3o34Dm1tKLgXwD4MhG1AHQB/JZS6ipb8HKudDqdmwCu9NUKB9pjHzHbkXG7csSo2vaeEdQJYNXuhzqAl4YQfh5Yr/dDru+FKOtmNSCi50Y5HB8W67VdwPpu2yhZz9d1vbYt7+3K5SdjBUEQhOwQoRcEQVjn5Enonxx1AENivbYLWN9tGyXr+bqu17blul258egFQRCE4ZCnjF4QBEEYAiMXeiJ6kIheIqKXiegLo45nUIjoVSL6IS09rfC53rZJIjpDRP/Y+7111HHGQPyTGtm20BJ/2Ps7/oCI7h1d5GuX9XQ/yL2Qn3thpEJPREUA/xnALwD4KQCfJKKfGmVMGfFRtfS0Qr3c6gsAvqeUugvA93rv1wInsPJJja62/AKWPhB0F5Y+yv9HqxTjumGd3g9yL+TgXhh1Rv8zAF5WSr2ilGoC+DMAHxtxTMPgYwC+1nv9NQAfH2Es0SjmSY1wt+VjAP5ULXEOwBYi2rk6ka4b3gn3g9wLI2DUQr8bwHnj/YXetrWMAnCaiJ6npQdVAcAtSqlLvddvALhlNKFlgqst6/Fvudqst2so98ISI/875v47Y9cgP6eUukhEO7D0kfh/MHcqpVQeH1rVD+upLcJQkHshJ4w6o78I4Dbj/a29bWsWpdTF3u83AXwTS8Pxy3ro1vv95ugiHBhXW9bd33IErKtrKPdCwsj/jqMW+u8DuIuI3kdEFQC/CuBbI46pb4hoExGN69cADgB4EUtt+lTvsE8B+MvRRJgJrrZ8C8Cv91YcfATADWNYK8Sxbu4HuRdydi+oiK+MG+YPgF/E0vds/hjAl0Ydz4BteT+AF3o/f6vbA2Ablmbl/xHA/wIwOepYI9vzdSx9eUYLSz7jb7raAoCwtGLkxwB+COC+Uce/Fn/Wy/0g90K+7gX5ZKwgCMI6Z9TWjSAIgjBkROgFQRDWOSL0giAI6xwRekEQhHWOCL0gCMI6R4ReEARhnSNCLwiCsM4RoRcEQVjn/H8G0bIQT+YWnQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#we have 8 examples in x, select random number(b/w 0 to 8) and view that image and mask\n", + "r = random.randint(0, len(x)-1)\n", + "\n", + "fig = plt.figure()\n", + "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", + "ax = fig.add_subplot(1, 2, 1)\n", + "ax.imshow(x[r])\n", + "ax = fig.add_subplot(1, 2, 2)\n", + "ax.imshow(np.reshape(y[r], (img_size, img_size)), cmap=\"gray\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Different Convolutional Blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 181, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def down_block(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n", + " p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)\n", + " return c, p\n", + "\n", + "def up_block(x, skip, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n", + " us = keras.layers.UpSampling2D((2, 2))(x)\n", + " concat = keras.layers.Concatenate()([us, skip])\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(concat)\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n", + " return c\n", + "\n", + "def bottleneck(x, filters, kernel_size=(3, 3), padding=\"same\", strides=1):\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(x)\n", + " c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation=\"relu\")(c)\n", + " return c" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# UNet Model" + ] + }, + { + "cell_type": "code", + "execution_count": 182, + "metadata": {}, + "outputs": [], + "source": [ + "def UNet():\n", + " f = [16, 32, 64, 128, 256]\n", + " inputs = keras.layers.Input((img_size, img_size, 3))\n", + " \n", + " p0 = inputs\n", + " c1, p1 = down_block(p0, f[0]) #128 -> 64\n", + " c2, p2 = down_block(p1, f[1]) #64 -> 32\n", + " c3, p3 = down_block(p2, f[2]) #32 -> 16\n", + " c4, p4 = down_block(p3, f[3]) #16->8\n", + " \n", + " bn = bottleneck(p4, f[4])\n", + " \n", + " u1 = up_block(bn, c4, f[3]) #8 -> 16\n", + " u2 = up_block(u1, c3, f[2]) #16 -> 32\n", + " u3 = up_block(u2, c2, f[1]) #32 -> 64\n", + " u4 = up_block(u3, c1, f[0]) #64 -> 128\n", + " \n", + " outputs = keras.layers.Conv2D(1, (1, 1), padding=\"same\", activation=\"sigmoid\")(u4)\n", + " model = keras.models.Model(inputs, outputs)\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 183, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_3\"\n", + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_4 (InputLayer) [(None, 128, 128, 3) 0 \n", + "__________________________________________________________________________________________________\n", + "conv2d_57 (Conv2D) (None, 128, 128, 16) 448 input_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_58 (Conv2D) (None, 128, 128, 16) 2320 conv2d_57[0][0] \n", + "__________________________________________________________________________________________________\n", + "max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 16) 0 conv2d_58[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_59 (Conv2D) (None, 64, 64, 32) 4640 max_pooling2d_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_60 (Conv2D) (None, 64, 64, 32) 9248 conv2d_59[0][0] \n", + "__________________________________________________________________________________________________\n", + "max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 32) 0 conv2d_60[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_61 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_62 (Conv2D) (None, 32, 32, 64) 36928 conv2d_61[0][0] \n", + "__________________________________________________________________________________________________\n", + "max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 64) 0 conv2d_62[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_63 (Conv2D) (None, 16, 16, 128) 73856 max_pooling2d_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_64 (Conv2D) (None, 16, 16, 128) 147584 conv2d_63[0][0] \n", + "__________________________________________________________________________________________________\n", + "max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 128) 0 conv2d_64[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_65 (Conv2D) (None, 8, 8, 256) 295168 max_pooling2d_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_66 (Conv2D) (None, 8, 8, 256) 590080 conv2d_65[0][0] \n", + "__________________________________________________________________________________________________\n", + "up_sampling2d_12 (UpSampling2D) (None, 16, 16, 256) 0 conv2d_66[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_12 (Concatenate) (None, 16, 16, 384) 0 up_sampling2d_12[0][0] \n", + " conv2d_64[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_67 (Conv2D) (None, 16, 16, 128) 442496 concatenate_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_68 (Conv2D) (None, 16, 16, 128) 147584 conv2d_67[0][0] \n", + "__________________________________________________________________________________________________\n", + "up_sampling2d_13 (UpSampling2D) (None, 32, 32, 128) 0 conv2d_68[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_13 (Concatenate) (None, 32, 32, 192) 0 up_sampling2d_13[0][0] \n", + " conv2d_62[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_69 (Conv2D) (None, 32, 32, 64) 110656 concatenate_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_70 (Conv2D) (None, 32, 32, 64) 36928 conv2d_69[0][0] \n", + "__________________________________________________________________________________________________\n", + "up_sampling2d_14 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_70[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_14 (Concatenate) (None, 64, 64, 96) 0 up_sampling2d_14[0][0] \n", + " conv2d_60[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_71 (Conv2D) (None, 64, 64, 32) 27680 concatenate_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_72 (Conv2D) (None, 64, 64, 32) 9248 conv2d_71[0][0] \n", + "__________________________________________________________________________________________________\n", + "up_sampling2d_15 (UpSampling2D) (None, 128, 128, 32) 0 conv2d_72[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_15 (Concatenate) (None, 128, 128, 48) 0 up_sampling2d_15[0][0] \n", + " conv2d_58[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_73 (Conv2D) (None, 128, 128, 16) 6928 concatenate_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_74 (Conv2D) (None, 128, 128, 16) 2320 conv2d_73[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_75 (Conv2D) (None, 128, 128, 1) 17 conv2d_74[0][0] \n", + "==================================================================================================\n", + "Total params: 1,962,625\n", + "Trainable params: 1,962,625\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "model = UNet()\n", + "model.compile(optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"accuracy\"])\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training the model" + ] + }, + { + "cell_type": "code", + "execution_count": 184, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5\n", + "2/2 [==============================] - 8s 4s/step - loss: 0.6737 - accuracy: 0.8045 - val_loss: 0.6439 - val_accuracy: 0.8043\n", + "Epoch 2/5\n", + "2/2 [==============================] - 4s 2s/step - loss: 0.4776 - accuracy: 0.7697 - val_loss: 0.0312 - val_accuracy: 0.7887\n", + "Epoch 3/5\n", + "2/2 [==============================] - 4s 2s/step - loss: 0.0438 - accuracy: 0.8628 - val_loss: 0.1271 - val_accuracy: 0.7887\n", + "Epoch 4/5\n", + "2/2 [==============================] - 4s 2s/step - loss: 0.1546 - accuracy: 0.8628 - val_loss: 0.0510 - val_accuracy: 0.7887\n", + "Epoch 5/5\n", + "2/2 [==============================] - 4s 2s/step - loss: 0.0799 - accuracy: 0.7118 - val_loss: 0.0211 - val_accuracy: 0.7887\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 184, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_gen = ProcessData(train_ids, train_path, img_size=img_size, batch_size=batch_size)\n", + "valid_gen = ProcessData(valid_ids, train_path, img_size=img_size, batch_size=batch_size)\n", + "\n", + "train_steps = len(train_ids)//batch_size\n", + "valid_steps = len(valid_ids)//batch_size\n", + "\n", + "model.fit_generator(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps, validation_steps=valid_steps, \n", + " epochs=epochs)" + ] + }, + { + "cell_type": "code", + "execution_count": 185, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_weights(\"UNetW.h5\")\n", + "\n", + "## Dataset for prediction\n", + "x, y = valid_gen.__getitem__(1)\n", + "result = model.predict(x)\n", + "\n", + "result = (result > 0.5).astype(np.uint8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plt.figure()\n", + "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", + "\n", + "ax = fig.add_subplot(1, 2, 1)\n", + "ax.imshow(np.reshape(y[0]*255, (img_size, img_size)), cmap=\"gray\")\n", + "\n", + "ax = fig.add_subplot(1, 2, 2)\n", + "ax.imshow(np.reshape(result[0]*255, (img_size, img_size)), cmap=\"gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}