From 27bf8916355f962bc362dd30949ddc9f59a80fd3 Mon Sep 17 00:00:00 2001 From: "Romain F. Laine" Date: Wed, 20 Jan 2021 11:44:29 +0000 Subject: [PATCH] v1.12 --- .../3D-RCAN_ZeroCostDL4Mic.ipynb | 2 +- .../DenoiSeg_2D_ZeroCostDL4Mic.ipynb | 1 - .../DenoiSeg_ZeroCostDL4Mic.ipynb | 1 + .../SplineDist_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/ChangeLog.txt | 12 + Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb | 2 +- .../Deep-STORM_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Latest_ZeroCostDL4Mic_Release.csv | 2 +- .../Noise2Void_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Noise2Void_3D_ZeroCostDL4Mic.ipynb | 2 +- .../StarDist_2D_ZeroCostDL4Mic.ipynb | 2 +- .../StarDist_3D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/Template_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb | 2355 +---------------- Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb | 2 +- 20 files changed, 30 insertions(+), 2371 deletions(-) delete mode 100644 Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb create mode 100644 Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb diff --git a/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb index 658e0552..c5f972a1 100644 --- a/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"3D-RCAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **3D-RCAN**\n","\n","\n","\n","---\n","\n","3D-RCAN is a neural network capable of image restoration from corrupted bio-images, first released in 2020 by [Chen *et al.* in biorXiv](https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1). \n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes**, by Chen *et al.* published in bioRxiv in 2020 (https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1)\n","\n","And source code found in: https://github.com/AiviaCommunity/3D-RCAN\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://www.dropbox.com/sh/hieldept1x476dw/AAC0pY3FrwdZBctvFF0Fx0L3a?dl=0).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. 
Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install 3D-RCAN and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"1kvDz2Ft4FX6"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.11.1']\n","\n","\n","#@markdown ##Install 3D-RCAN and dependencies\n","\n","!git clone https://github.com/AiviaCommunity/3D-RCAN\n","\n","import os\n","\n","\n","!pip install q keras==2.2.5\n","\n","!pip install colorama; sys_platform=='win32'\n","!pip install jsonschema\n","!pip install numexpr\n","!pip install tqdm>=4.41.0\n","\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"IWhWPjyi33M2"},"source":["## **2.2. 
Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"SBag2atY36Js"},"source":["** Here you need to restart your runtime to load the newly installed dependencies**\n","\n"," Click on \"Runtime\" ---> \"Restart Runtime\""]},{"cell_type":"markdown","metadata":{"id":"_nRLOjuk3_8z"},"source":["## **2.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"TYYBwn_54G9j"},"source":["Notebook_version = ['1.11.1']\n","\n","#@markdown ##Load key dependencies\n","\n","!pip install q keras==2.2.5\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to 3D-RCAN -------\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. 
Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: 256**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`num_residual_groups`:** Number of residual groups in RCAN. **Default value: 5** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the num_residual_groups value until the OOM error disappear.**\n","\n","**`num_residual_blocks`:** Number of residual channel attention blocks in each residual group in RCAN. **Default value: 3** \n","\n","**`num_channels`:** Number of feature channels in RCAN. **Default value: 32** \n","\n","**`channel_reduction`:** Channel reduction ratio for channel attention. 
**Default value: 8** \n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","number_of_steps = 256#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","percentage_validation = 10 #@param {type:\"number\"}\n","num_residual_groups = 5 #@param {type:\"number\"}\n","num_residual_blocks = 3 #@param {type:\"number\"}\n","num_channels = 32 #@param {type:\"number\"}\n","channel_reduction = 8 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," percentage_validation = 10\n"," num_residual_groups = 5\n"," num_channels = 32\n"," num_residual_blocks = 3\n"," channel_reduction = 8\n"," \n","\n","percentage = percentage_validation/100\n","\n","\n","full_model_path = model_path+'/'+model_name\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n"," shutil.move(Training_target_temp+\"/\"+name, Validation_target_temp+\"/\"+name)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_3D_RCAN.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = False #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source_temp,Training_target_temp,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_temp,Training_target_temp)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","print(\"Preparing the config file...\")\n","\n","if Use_Data_augmentation == True:\n"," Training_source_temp = Saving_path+'/augmented_source'\n"," Training_target_temp = Saving_path+'/augmented_target'\n","\n","# Here we prepare the JSON file\n","\n","import json \n"," \n","# Config file for 3D-RCAN \n","dictionary ={\n"," \"epochs\": number_of_epochs,\n"," \"steps_per_epoch\": number_of_steps,\n"," \"num_residual_groups\": num_residual_groups,\n"," \"training_data_dir\": {\"raw\": Training_source_temp,\n"," \"gt\": Training_target_temp},\n"," \n"," \"validation_data_dir\": {\"raw\": Validation_source_temp,\n"," \"gt\": Validation_target_temp},\n"," \"num_channels\": num_channels,\n"," \"num_residual_blocks\": num_residual_blocks,\n"," \"channel_reduction\": channel_reduction\n"," \n"," \n","}\n"," \n","json_object = json.dumps(dictionary, indent = 4) \n"," \n","with open(\"/content/config.json\", \"w\") as outfile: \n"," outfile.write(json_object)\n","\n","print(\"Done\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","!python /content/3D-RCAN/train.py -c /content/config.json -o \"$full_model_path\"\n","\n","print(\"Training, done.\")\n","\n","\n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = '3D-RCAN'\n","day = datetime.now()\n","datetime_str = str(day)[0:16]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = 
all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs (image dimensions: '+str(shape)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n","
<tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>number_of_steps</td><td>{1}</td></tr>
<tr><td>percentage_validation</td><td>{2}</td></tr>
<tr><td>num_residual_groups</td><td>{3}</td></tr>
<tr><td>num_residual_blocks</td><td>{4}</td></tr>
<tr><td>num_channels</td><td>{5}</td></tr>
<tr><td>channel_reduction</td><td>{6}</td></tr>
\n","\"\"\".format(number_of_epochs,number_of_steps, percentage_validation, num_residual_groups, num_residual_blocks, num_channels, channel_reduction)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_3D_RCAN.png').shape\n","pdf.image('/content/TrainingDataExample_3D_RCAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- 3D-RCAN: Chen et al. \"Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes.\" bioRxiv 2020 https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\"\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","path_QC_prediction = path_metrics_save+'Prediction'\n","\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_QC_prediction):\n"," shutil.rmtree(path_QC_prediction)\n","os.makedirs(path_QC_prediction)\n","\n","\n","# Perform the predictions\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_QC_model_path\" -i \"$Source_QC_folder\" -o \"$path_QC_prediction\"\n","\n","print(\"Done...\")\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import 
numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack_raw = io.imread(os.path.join(path_metrics_save+\"Prediction/\",thisFile))\n"," test_prediction_stack = test_prediction_stack_raw[:, 1, :, :]\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = 
img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction_raw = io.imread(os.path.join(path_metrics_save+'Prediction/', Test_FileList[-1]))\n","\n","img_Prediction = img_Prediction_raw[:, 1, :, :]\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = '3D RCAN'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='You can see these curves in the notebook.')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes, by Chen et al. bioRxiv (2020)'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_Prediction_model_path\" -i \"$Data_folder\" -o \"$Result_folder\"\n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y_raw = imread(Result_folder+\"/\"+file)\n"," y = y_raw[:, 1, :, :]\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using 3D-RCAN!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"3D-RCAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1JHO_gnWRtiFhhD5YgLE2UwgTB-63MVii","timestamp":1610723892159},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **3D-RCAN**\n","\n","\n","\n","---\n","\n","3D-RCAN is a neural network capable of image restoration from corrupted bio-images, first released in 2020 by [Chen *et al.* in biorXiv](https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1). \n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes**, by Chen *et al.* published in bioRxiv in 2020 (https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1)\n","\n","And source code found in: https://github.com/AiviaCommunity/3D-RCAN\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://www.dropbox.com/sh/hieldept1x476dw/AAC0pY3FrwdZBctvFF0Fx0L3a?dl=0).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. 
Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
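(Illustrative aside, not a cell of the notebook.) A quick way to verify the paired data structure described in section 0 is a few lines of Python: the requirement is simply that every low-SNR image has a high-SNR counterpart with exactly the same file name. The folder paths below are placeholders for your own Training_source and Training_target folders.

import os

training_source = "/path/to/Training - Low SNR images"    # placeholder path, not part of the notebook
training_target = "/path/to/Training - High SNR images"   # placeholder path, not part of the notebook

source_files = {f for f in os.listdir(training_source) if f.lower().endswith(".tif")}
target_files = {f for f in os.listdir(training_target) if f.lower().endswith(".tif")}

unpaired = sorted(source_files ^ target_files)   # symmetric difference: names present in only one folder
if unpaired:
    print("These files have no identically named partner:", unpaired)
else:
    print(len(source_files), "matching source/target pairs found.")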
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install 3D-RCAN and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"1kvDz2Ft4FX6"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install 3D-RCAN and dependencies\n","\n","!git clone https://github.com/AiviaCommunity/3D-RCAN\n","\n","import os\n","\n","\n","!pip install q keras==2.2.5\n","\n","!pip install colorama; sys_platform=='win32'\n","!pip install jsonschema\n","!pip install numexpr\n","!pip install tqdm>=4.41.0\n","\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"IWhWPjyi33M2"},"source":["## **2.2. 
Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"SBag2atY36Js"},"source":["** Here you need to restart your runtime to load the newly installed dependencies**\n","\n"," Click on \"Runtime\" ---> \"Restart Runtime\""]},{"cell_type":"markdown","metadata":{"id":"_nRLOjuk3_8z"},"source":["## **2.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"TYYBwn_54G9j","cellView":"form"},"source":["Notebook_version = ['1.11.1']\n","\n","#@markdown ##Load key dependencies\n","\n","!pip install q keras==2.2.5\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to 3D-RCAN -------\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = '3D-RCAN'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs (image dimensions: '+str(shape)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n","
ParameterValue
number_of_epochs{0}
number_of_steps{1}
percentage_validation{2}
num_residual_groups{3}
num_residual_blocks{4}
num_channels{5}
channel_reduction{6}
\n"," \"\"\".format(number_of_epochs,number_of_steps, percentage_validation, num_residual_groups, num_residual_blocks, num_channels, channel_reduction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_3D_RCAN.png').shape\n"," pdf.image('/content/TrainingDataExample_3D_RCAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- 3D-RCAN: Chen et al. \"Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes.\" bioRxiv 2020 https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = '3D RCAN'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='You can see these curves in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes, by Chen et al. bioRxiv (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: 256**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`num_residual_groups`:** Number of residual groups in RCAN. **Default value: 5** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the num_residual_groups value until the OOM error disappear.**\n","\n","**`num_residual_blocks`:** Number of residual channel attention blocks in each residual group in RCAN. **Default value: 3** \n","\n","**`num_channels`:** Number of feature channels in RCAN. 
**Default value: 32** \n","\n","**`channel_reduction`:** Channel reduction ratio for channel attention. **Default value: 8** \n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","number_of_steps = 256#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","percentage_validation = 10 #@param {type:\"number\"}\n","num_residual_groups = 5 #@param {type:\"number\"}\n","num_residual_blocks = 3 #@param {type:\"number\"}\n","num_channels = 32 #@param {type:\"number\"}\n","channel_reduction = 8 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," percentage_validation = 10\n"," num_residual_groups = 5\n"," num_channels = 32\n"," num_residual_blocks = 3\n"," channel_reduction = 8\n"," \n","\n","percentage = percentage_validation/100\n","\n","\n","full_model_path = model_path+'/'+model_name\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n"," shutil.move(Training_target_temp+\"/\"+name, Validation_target_temp+\"/\"+name)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_3D_RCAN.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
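(Illustrative aside.) The rotation and flip augmentations described below multiply the dataset size; a minimal numpy sketch with a dummy z-stack (not notebook data) shows how three in-plane 90-degree rotations plus a flipped copy of each variant turn one stack into eight, matching the factor of 8 quoted below.

import numpy as np

stack = np.random.rand(8, 64, 64)   # dummy (z, y, x) stack standing in for one training image

# original + three 90-degree rotations in the XY plane (axes 1 and 2), as in rotation_aug() below
rotations = [np.rot90(stack, k=k, axes=(1, 2)) for k in range(4)]
# one flipped copy of each rotation, as when Flip is also enabled
augmented = rotations + [np.fliplr(r) for r in rotations]

print(len(augmented), "stacks generated from a single input")   # prints: 8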
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = False #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source_temp,Training_target_temp,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_temp,Training_target_temp)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","print(\"Preparing the config file...\")\n","\n","if Use_Data_augmentation == True:\n"," Training_source_temp = Saving_path+'/augmented_source'\n"," Training_target_temp = Saving_path+'/augmented_target'\n","\n","# Here we prepare the JSON file\n","\n","import json \n"," \n","# Config file for 3D-RCAN \n","dictionary ={\n"," \"epochs\": number_of_epochs,\n"," \"steps_per_epoch\": number_of_steps,\n"," \"num_residual_groups\": num_residual_groups,\n"," \"training_data_dir\": {\"raw\": Training_source_temp,\n"," \"gt\": Training_target_temp},\n"," \n"," \"validation_data_dir\": {\"raw\": Validation_source_temp,\n"," \"gt\": Validation_target_temp},\n"," \"num_channels\": num_channels,\n"," \"num_residual_blocks\": num_residual_blocks,\n"," \"channel_reduction\": channel_reduction\n"," \n"," \n","}\n"," \n","json_object = json.dumps(dictionary, indent = 4) \n"," \n","with open(\"/content/config.json\", \"w\") as outfile: \n"," outfile.write(json_object)\n","\n","# Export pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation)\n","\n","print(\"Done\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","!python /content/3D-RCAN/train.py -c /content/config.json -o \"$full_model_path\"\n","\n","print(\"Training, done.\")\n","\n","\n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\"\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","path_QC_prediction = path_metrics_save+'Prediction'\n","\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_QC_prediction):\n"," shutil.rmtree(path_QC_prediction)\n","os.makedirs(path_QC_prediction)\n","\n","\n","# Perform the predictions\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_QC_model_path\" -i \"$Source_QC_folder\" -o \"$path_QC_prediction\"\n","\n","print(\"Done...\")\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import 
numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack_raw = io.imread(os.path.join(path_metrics_save+\"Prediction/\",thisFile))\n"," test_prediction_stack = test_prediction_stack_raw[:, 1, :, :]\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = 
img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction_raw = io.imread(os.path.join(path_metrics_save+'Prediction/', Test_FileList[-1]))\n","\n","img_Prediction = img_Prediction_raw[:, 1, :, :]\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_Prediction_model_path\" -i \"$Data_folder\" -o \"$Result_folder\"\n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y_raw = imread(Result_folder+\"/\"+file)\n"," y = y_raw[:, 1, :, :]\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using 3D-RCAN!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb deleted file mode 100644 index ac9cc058..00000000 --- a/Colab_notebooks/Beta notebooks/DenoiSeg_2D_ZeroCostDL4Mic.ipynb +++ /dev/null @@ -1 +0,0 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DenoiSeg_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Image denoising and segmentation using DenoiSeg 2D**\n","\n","---\n","\n"," DenoiSeg 2D is deep-learning method that can be used to jointly denoise and segment 2D microscopy images. By running this notebook, you can train your and use you own network. \n","\n"," The benefits of using DenoiSeg (compared to other Deep Learning-based segmentation methods) are more prononced when only a few annotated images are available. However, the denoising part requires many images to perform well. 
All the noisy images don't need to be labeled to train DenoiSeg.\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper: **DenoiSeg: Joint Denoising and Segmentation**\n","Tim-Oliver Buchholz, Mangal Prakash, Alexander Krull, Florian Jug\n","https://arxiv.org/abs/2005.02987\n","\n","And source code found in: https://github.com/juglab/DenoiSeg/wiki\n","\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","**it needs to have access to a paired training dataset made of images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**Importantly, the benefits of using DenoiSeg are more pronounced when only limited numbers of segmentation annotations are available for training. However, DenoiSeg also expects that lots of noisy raw images are available to train the denoising part. It is therefore not required for all the noisy images to be annotated to train DenoiSeg**.\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Noisy Images (Training_source)\n"," - img_1.tif, img_2.tif, img_3.tif, img_4.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif\n"," - **Quality control dataset (optional, not required for training)**\n"," - Noisy Images\n"," - img_1.tif, img_2.tif\n"," - High SNR Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install DenoiSeg and Dependencies**\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"fq21zJVFNASx"},"source":["Notebook_version = ['1.11']\n","\n","#@markdown ##Install DenoiSeg and dependencies\n","!pip install q keras==2.2.5\n","\n","# Here we enable Tensorflow 1. 
\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install denoiseg\n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to Denoiseg -------\n","\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","import numpy as np\n","from matplotlib import pyplot as plt\n","from scipy import ndimage\n","\n","from denoiseg.models import DenoiSeg, DenoiSegConfig\n","from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data\n","from denoiseg.utils.seg_utils import *\n","from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels\n","\n","from csbdeep.utils import plot_history\n","from tifffile import imread, imsave\n","from glob import glob\n","\n","import urllib\n","import os\n","import zipfile\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. 
Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`Priority`:** Choose how much relative the importance to assign to the denoising \n","and segmentation tasks by choosing an appropriate value (between 0 and 1; with 0 being only segmentation and 1 being only denoising. **Default value: 0.5**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: depends on number of patches, min 100; max 400**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","Priority = 0.5#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","batch_size = 128#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," Priority = 0.5\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_target))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","# Here we count the number of files in the training target folder\n","Mask_Filelist = os.listdir(Training_target)\n","Mask_number_files = len(Mask_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Mask_for_validation = int((Mask_number_files)/percentage_validation)\n","\n","if Mask_for_validation == 0:\n"," Mask_for_validation = 2\n","if Mask_for_validation == 1:\n"," Mask_for_validation = 2\n","\n","# Here we count the number of files in the training target folder\n","Noisy_Filelist = os.listdir(Training_source)\n","Noisy_number_files = len(Noisy_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n","\n","if Noisy_for_validation == 0:\n"," Noisy_for_validation = 1\n","\n","#Here we find the noisy images that do not have masks\n","noisy_image_no_mask_list = list(set(Noisy_Filelist) - set(Mask_Filelist))\n","\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","#Here we move images to be used for validation\n","for i in range(Mask_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+list_target[i], Validation_source_temp+\"/\"+list_target[i])\n"," shutil.move(Training_target_temp+\"/\"+list_target[i], Validation_target_temp+\"/\"+list_target[i])\n","\n","#Here we move a few more noisy images for validation\n","if noisy_image_no_mask_list:\n"," for y in range(Noisy_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+noisy_image_no_mask_list[y], 
Validation_source_temp+\"/\"+noisy_image_no_mask_list[y])\n","\n","\n","print(\"Parameters initiated.\")\n","\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DenoiSeg.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis (multiply the dataset by 8). \n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," \n","\n","\n"," "]},{"cell_type":"code","metadata":{"id":"VipPCXmwL1YN","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a DenoiSeg model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = 
initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","id":"rBelu-LtbOTh"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","print(\"In progress...\")\n","\n","Training_source_dir = Training_source_temp\n","Training_target_dir = Training_target_temp\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","validation_images_tiff=Validation_source_temp+\"/*.tif\"\n","validation_mask_tiff=Validation_target_temp+\"/*.tif\"\n","\n","train_images = imread(sorted(glob(training_images_tiff)))\n","val_images = imread(sorted(glob(validation_images_tiff)))\n","\n","available_train_masks = imread(sorted(glob(mask_images_tiff)))\n","available_val_masks = imread(sorted(glob(validation_mask_tiff)))\n","\n","#This allows the users to not have all their training images segmented\n","blank_images_train = np.zeros((train_images.shape[0]-available_train_masks.shape[0], available_train_masks.shape[1], available_train_masks.shape[2]))\n","blank_images_val = np.zeros((val_images.shape[0]-available_val_masks.shape[0], available_val_masks.shape[1], available_val_masks.shape[2]))\n","blank_images_train = blank_images_train.astype(\"uint16\")\n","blank_images_val = blank_images_val.astype(\"uint16\")\n","\n","train_masks = np.concatenate((available_train_masks,blank_images_train), axis = 0)\n","val_masks = np.concatenate((available_val_masks,blank_images_val), axis = 0)\n","\n","\n","if not Use_Data_augmentation:\n"," X, Y_train_masks = train_images, train_masks\n","\n","# Now we apply data augmentation to the training patches:\n","# Rotate four times by 90 degree and add flipped versions.\n","if Use_Data_augmentation:\n"," X, Y_train_masks = augment_data(train_images, train_masks)\n","\n","X_val, Y_val_masks = val_images, val_masks\n","\n","# Here we add the channel dimension to our input images.\n","# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)\n","X = X[...,np.newaxis]\n","Y = convert_to_oneHot(Y_train_masks)\n","X_val = X_val[...,np.newaxis]\n","Y_val = convert_to_oneHot(Y_val_masks)\n","print(\"Shape of X: {}\".format(X.shape))\n","print(\"Shape of Y: {}\".format(Y.shape))\n","print(\"Shape of X_val: {}\".format(X_val.shape))\n","print(\"Shape of Y_val: {}\".format(Y_val.shape))\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= max(100, min(int(X.shape[0]/batch_size), 400))\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","\n","config = DenoiSegConfig(X, unet_kern_size=3, n_channel_out=4, relative_weights = [1.0,1.0,5.0],\n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," batch_norm=True, train_batch_size=batch_size, unet_n_first = 32, \n"," unet_n_depth=4, denoiseg_alpha=Priority, 
train_learning_rate = initial_learning_rate, train_tensorboard=False)\n","\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","\n","model = DenoiSeg(config=config, name=model_name, basedir=model_path)\n","\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (DenoiSeg -- DenoiSeg Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"cellView":"form","id":"fisJmA13Mv5e","scrolled":true},"source":["start = time.time()\n","\n","#@markdown ##Start Training\n","%memit\n","\n","\n","\n","history = model.train(X, Y, (X_val, Y_val))\n","\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), val_masks, measure=measure_precision())\n","\n","print(\"The higest score of {} is achieved with threshold = {}.\".format(np.round(val_score, 3), threshold))\n","\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate','threshold'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i], str(threshold)])\n","\n","#Thresholdpath = model_path+'/'+model_name+'/Quality Control/optimal_threshold.csv'\n","#with open(Thresholdpath, 'w') as f1:\n"," #writer1 = csv.writer(f1)\n"," #writer1.writerow(['threshold'])\n"," #writer1.writerow([str(threshold)])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='DenoiSeg', \n"," description='DenoiSeg 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(64, 64))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf\n","\n","from datetime import datetime\n","\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'DenoiSeg 2D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+' and using the '+Network+' 
ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'Data augmentation was enabled'\n","\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
batch_size{1}
number_of_steps{2}
percentage_validation{3}
Priority{4}
initial_learning_rate{5}
\n","\"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,Priority,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_DenoiSeg.png').shape\n","pdf.image('/content/TrainingDataExample_DenoiSeg.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- DenoiSeg: Buchholz, Prakash, et al. \"DenoiSeg: Joint Denoising and Segmentation\", arXiv 2020.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","\n","\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","**DenoiSeg** allow to both denoise and segment microscopy images. This section allow you to evaluate both tasks separetly.\n","\n","**Evaluation of the denoising**\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_Denoising_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","**Evaluation of the Segmentation**\n","\n","This option will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_Segmentation_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","cellView":"form"},"source":["#@markdown ##Choose what to evaluate\n","\n","Evaluate_Denoising = False #@param {type:\"boolean\"}\n","\n","Evaluate_Segmentation = True #@param {type:\"boolean\"}\n","\n","\n","# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_Denoising_folder = \"\" #@param{type:\"string\"}\n","Target_Segmentation_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","#Load the threshold value. 
\n","\n","if os.path.exists(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# ------------- Prepare the model and run predictions ------------\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Source_QC_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_segmentation_\"+base_filename, segmented_images)\n","\n","# ------------- Here we Start assessing the denoising against GT ------------\n","\n","if Evaluate_Denoising:\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = 
x.astype(np.float32, copy=False) - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_Denoising_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_denoised_\"+i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," 
# We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n"," Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n"," norm = simple_norm(x, percent = 99)\n","\n"," plt.figure(figsize=(15,15))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(Target_Denoising_folder, Test_FileList[-1]))\n"," plt.imshow(img_GT, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n"," plt.imshow(img_Source, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", \"Predicted_denoised_\"+Test_FileList[-1]))\n"," plt.imshow(img_Prediction, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_denoising.png',bbox_inches='tight',pad_inches=0)\n","#________________________________________________________________________\n","# Here we start testing the differences between GT and predicted masks\n","\n","if Evaluate_Segmentation:\n","\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_segmentation_\"+n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_Segmentation_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n"," from astropy.visualization import simple_norm\n","\n"," # ------------- For display ------------\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def show_QC_results(file = os.listdir(Source_QC_folder)):\n","\n"," plt.figure(figsize=(25,5))\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n"," target_image = io.imread(os.path.join(Target_Segmentation_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/Predicted_segmentation_\"+file, as_gray = True)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_segmentation.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'DenoiSeg_2D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control 
report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n","\n","if Evaluate_Segmentation:\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_segmentation.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1] \n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,IoU)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = row[0]\n"," IoU = row[1] \n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}
{0}{1}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n","if Evaluate_Denoising:\n","\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_denoising.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_denoising.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","print('------------------------------')\n","print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"bl3EdYFVS7X9"},"source":["import imageio\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a Threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","#Load the threshold value. 
\n","\n","if os.path.exists(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","print(\"Processing...\")\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(Result_folder+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(Result_folder+\"/\"+\"Predicted_segmentation_\"+base_filename,segmented_images)\n"," \n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"jFp-0y4zT_gL"},"source":["\n","# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+\"Predicted_denoised_\"+random_choice)\n","z = imread(Result_folder+\"/\"+\"Predicted_segmentation_\"+random_choice)\n","\n","norm = simple_norm(x, percent = 99)\n","\n","plt.figure(figsize=(30,15))\n","plt.subplot(1, 4, 1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Input\")\n","\n","plt.subplot(1, 4, 2)\n","plt.imshow(y, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Predicted denoised image\")\n","\n","plt.subplot(1, 4, 3)\n","plt.imshow(z, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.axis('off');\n","plt.title(\"Predicted segmentation\")\n","\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using DenoiSeg!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..92a50382 --- /dev/null +++ b/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"DenoiSeg_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611075289867},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Image denoising and segmentation using DenoiSeg 2D**\n","\n","---\n","\n"," DenoiSeg 2D is deep-learning method that can be used to jointly denoise and segment 2D microscopy images. By running this notebook, you can train your and use you own network. \n","\n"," The benefits of using DenoiSeg (compared to other Deep Learning-based segmentation methods) are more prononced when only a few annotated images are available. However, the denoising part requires many images to perform well. All the noisy images don't need to be labeled to train DenoiSeg.\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper: **DenoiSeg: Joint Denoising and Segmentation**\n","Tim-Oliver Buchholz, Mangal Prakash, Alexander Krull, Florian Jug\n","https://arxiv.org/abs/2005.02987\n","\n","And source code found in: https://github.com/juglab/DenoiSeg/wiki\n","\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","**it needs to have access to a paired training dataset made of images and their corresponding masks**. 
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**Importantly, the benefits of using DenoiSeg are more pronounced when only limited numbers of segmentation annotations are available for training. However, DenoiSeg also expects that lots of noisy raw images are available to train the denoising part. It is therefore not required for all the noisy images to be annotated to train DenoiSeg**.\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Noisy Images (Training_source)\n"," - img_1.tif, img_2.tif, img_3.tif, img_4.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif\n"," - **Quality control dataset (optional, not required for training)**\n"," - Noisy Images\n"," - img_1.tif, img_2.tif\n"," - High SNR Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install DenoiSeg and Dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","#@markdown ##Install DenoiSeg and dependencies\n","!pip install q keras==2.2.5\n","\n","# Here we enable Tensorflow 1. 
\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install denoiseg\n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to Denoiseg -------\n","\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","import numpy as np\n","from matplotlib import pyplot as plt\n","from scipy import ndimage\n","\n","from denoiseg.models import DenoiSeg, DenoiSegConfig\n","from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data\n","from denoiseg.utils.seg_utils import *\n","from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels\n","\n","from csbdeep.utils import plot_history\n","from tifffile import imread, imsave\n","from glob import glob\n","\n","import urllib\n","import os\n","import zipfile\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'DenoiSeg 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'Data augmentation was enabled'\n","\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% style=\"text-align:left\">Parameter</th>\n","    <th width = 50% style=\"text-align:left\">Value</th>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_steps</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>percentage_validation</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>Priority</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>initial_learning_rate</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","  </table>
\n"," \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,Priority,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_DenoiSeg.png').shape\n"," pdf.image('/content/TrainingDataExample_DenoiSeg.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- DenoiSeg: Buchholz, Prakash, et al. \"DenoiSeg: Joint Denoising and Segmentation\", arXiv 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'DenoiSeg_2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n","\n"," if 
Evaluate_Segmentation:\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_segmentation.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1] \n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,IoU)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = row[0]\n"," IoU = row[1] \n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," if Evaluate_Denoising:\n","\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_denoising.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_denoising.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","!pip freeze > requirements.txt\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`Priority`:** Choose how much relative the importance to assign to the denoising \n","and segmentation tasks by choosing an appropriate value (between 0 and 1; with 0 being only segmentation and 1 being only denoising. **Default value: 0.5**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: depends on number of patches, min 100; max 400**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. 
**Default value: 128**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","Priority = 0.5#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","batch_size = 128#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," Priority = 0.5\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_target))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","# Here we count the number of files in the training target folder\n","Mask_Filelist = os.listdir(Training_target)\n","Mask_number_files = len(Mask_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Mask_for_validation = int((Mask_number_files)/percentage_validation)\n","\n","if Mask_for_validation == 0:\n"," Mask_for_validation = 2\n","if Mask_for_validation == 1:\n"," Mask_for_validation = 2\n","\n","# Here we count the number of files in the training target folder\n","Noisy_Filelist = os.listdir(Training_source)\n","Noisy_number_files = len(Noisy_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n","\n","if Noisy_for_validation == 0:\n"," Noisy_for_validation = 1\n","\n","#Here we find the noisy images that do not have masks\n","noisy_image_no_mask_list = list(set(Noisy_Filelist) - set(Mask_Filelist))\n","\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = 
\"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","#Here we move images to be used for validation\n","for i in range(Mask_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+list_target[i], Validation_source_temp+\"/\"+list_target[i])\n"," shutil.move(Training_target_temp+\"/\"+list_target[i], Validation_target_temp+\"/\"+list_target[i])\n","\n","#Here we move a few more noisy images for validation\n","if noisy_image_no_mask_list:\n"," for y in range(Noisy_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+noisy_image_no_mask_list[y], Validation_source_temp+\"/\"+noisy_image_no_mask_list[y])\n","\n","\n","print(\"Parameters initiated.\")\n","\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DenoiSeg.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis (multiply the dataset by 8). \n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a DenoiSeg model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in 
csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YOp8HwavpoON"},"source":["# **4. Train the network**\r\n","---"]},{"cell_type":"markdown","metadata":{"id":"CKSKY4icpcKb"},"source":["## **4.1. Prepare the training data and model for training**\r\n","---\r\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","id":"0LM_L-5Spb2z"},"source":["#@markdown ##Create the model and dataset objects\r\n","\r\n","# --------------------- Here we delete the model folder if it already exist ------------------------\r\n","\r\n","if os.path.exists(model_path+'/'+model_name):\r\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\r\n"," shutil.rmtree(model_path+'/'+model_name)\r\n","\r\n","\r\n","# --------------------- Here we load the augmented data or the raw data ------------------------\r\n","\r\n","print(\"In progress...\")\r\n","\r\n","Training_source_dir = Training_source_temp\r\n","Training_target_dir = Training_target_temp\r\n","# --------------------- ------------------------------------------------\r\n","\r\n","training_images_tiff=Training_source_dir+\"/*.tif\"\r\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\r\n","\r\n","validation_images_tiff=Validation_source_temp+\"/*.tif\"\r\n","validation_mask_tiff=Validation_target_temp+\"/*.tif\"\r\n","\r\n","train_images = imread(sorted(glob(training_images_tiff)))\r\n","val_images = imread(sorted(glob(validation_images_tiff)))\r\n","\r\n","available_train_masks = imread(sorted(glob(mask_images_tiff)))\r\n","available_val_masks = imread(sorted(glob(validation_mask_tiff)))\r\n","\r\n","#This allows the users to not have all their training images segmented\r\n","blank_images_train = np.zeros((train_images.shape[0]-available_train_masks.shape[0], available_train_masks.shape[1], available_train_masks.shape[2]))\r\n","blank_images_val = np.zeros((val_images.shape[0]-available_val_masks.shape[0], available_val_masks.shape[1], available_val_masks.shape[2]))\r\n","blank_images_train = blank_images_train.astype(\"uint16\")\r\n","blank_images_val = blank_images_val.astype(\"uint16\")\r\n","\r\n","train_masks = np.concatenate((available_train_masks,blank_images_train), axis = 0)\r\n","val_masks = np.concatenate((available_val_masks,blank_images_val), axis = 0)\r\n","\r\n","\r\n","if not Use_Data_augmentation:\r\n"," X, Y_train_masks = train_images, train_masks\r\n","\r\n","# Now we apply data augmentation to the training patches:\r\n","# Rotate four times by 90 degree and add flipped versions.\r\n","if Use_Data_augmentation:\r\n"," X, Y_train_masks = augment_data(train_images, train_masks)\r\n","\r\n","X_val, Y_val_masks = val_images, val_masks\r\n","\r\n","# Here we add the channel dimension to our input images.\r\n","# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)\r\n","X = X[...,np.newaxis]\r\n","Y = convert_to_oneHot(Y_train_masks)\r\n","X_val = X_val[...,np.newaxis]\r\n","Y_val = convert_to_oneHot(Y_val_masks)\r\n","print(\"Shape of X: {}\".format(X.shape))\r\n","print(\"Shape of Y: {}\".format(Y.shape))\r\n","print(\"Shape of X_val: {}\".format(X_val.shape))\r\n","print(\"Shape of Y_val: {}\".format(Y_val.shape))\r\n","\r\n","#Here we automatically define number_of_step in function of training data and batch size\r\n","#Here we ensure that our network has a minimal number of steps\r\n","if (Use_Default_Advanced_Parameters): \r\n"," number_of_steps= max(100, min(int(X.shape[0]/batch_size), 400))\r\n","\r\n","\r\n","# --------------------- Using pretrained model ------------------------\r\n","#Here we ensure that the learning rate set correctly when using pre-trained models\r\n","if Use_pretrained_model:\r\n"," if Weights_choice == \"last\":\r\n"," initial_learning_rate = lastLearningRate\r\n","\r\n"," if Weights_choice == \"best\": \r\n"," initial_learning_rate = bestLearningRate\r\n","# --------------------- ---------------------- ------------------------\r\n","\r\n","# create a Config object\r\n","\r\n","config = DenoiSegConfig(X, unet_kern_size=3, n_channel_out=4, relative_weights = [1.0,1.0,5.0],\r\n"," train_steps_per_epoch=number_of_steps, 
train_epochs=number_of_epochs, \r\n"," batch_norm=True, train_batch_size=batch_size, unet_n_first = 32, \r\n"," unet_n_depth=4, denoiseg_alpha=Priority, train_learning_rate = initial_learning_rate, train_tensorboard=False)\r\n","\r\n","\r\n","# Let's look at the parameters stored in the config-object.\r\n","vars(config)\r\n"," \r\n"," \r\n","# create network model.\r\n","\r\n","model = DenoiSeg(config=config, name=model_name, basedir=model_path)\r\n","\r\n","\r\n","\r\n","# --------------------- Using pretrained model ------------------------\r\n","# Load the pretrained weights \r\n","if Use_pretrained_model:\r\n"," model.load_weights(h5_file_path)\r\n","# --------------------- ---------------------- ------------------------\r\n","\r\n","#Export summary of training parameters as pdf\r\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\r\n","\r\n","print(\"Setup done.\")\r\n","print(config)\r\n","\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (DenoiSeg -- DenoiSeg Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"xlcY9dvfm67C"},"source":["start = time.time()\r\n","\r\n","#@markdown ##Start Training\r\n","%memit\r\n","\r\n","\r\n","\r\n","history = model.train(X, Y, (X_val, Y_val))\r\n","\r\n","print(\"Training done.\")\r\n","%memit\r\n","\r\n","\r\n","print(\"Training, done.\")\r\n","\r\n","threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), val_masks, measure=measure_precision())\r\n","\r\n","print(\"The higest score of {} is achieved with threshold = {}.\".format(np.round(val_score, 3), threshold))\r\n","\r\n","\r\n","\r\n","# convert the history.history dict to a pandas DataFrame: \r\n","lossData = pd.DataFrame(history.history) \r\n","\r\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\r\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\r\n","\r\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\r\n","\r\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
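\r\n","# Added note (not part of the original notebook): the CSV written below has one row per epoch\r\n","# with the columns 'loss', 'val_loss', 'learning rate' and 'threshold'; it can be re-loaded later\r\n","# for inspection, e.g. pd.read_csv(model_path+'/'+model_name+'/Quality Control/training_evaluation.csv').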
\r\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\r\n","with open(lossDataCSVpath, 'w') as f:\r\n"," writer = csv.writer(f)\r\n"," writer.writerow(['loss','val_loss', 'learning rate','threshold'])\r\n"," for i in range(len(history.history['loss'])):\r\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i], str(threshold)])\r\n","\r\n","#Thresholdpath = model_path+'/'+model_name+'/Quality Control/optimal_threshold.csv'\r\n","#with open(Thresholdpath, 'w') as f1:\r\n"," #writer1 = csv.writer(f1)\r\n"," #writer1.writerow(['threshold'])\r\n"," #writer1.writerow([str(threshold)])\r\n","\r\n","\r\n","# Displaying the time elapsed for training\r\n","dt = time.time() - start\r\n","mins, sec = divmod(dt, 60) \r\n","hour, mins = divmod(mins, 60) \r\n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\r\n","\r\n","model.export_TF(name='DenoiSeg', \r\n"," description='DenoiSeg 2D trained using ZeroCostDL4Mic.', \r\n"," authors=[\"You\"],\r\n"," test_img=X_val[0,...,0], axes='YX',\r\n"," patch_shape=(64, 64))\r\n","\r\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\r\n","\r\n","#Create a pdf document with training summary\r\n","\r\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","**DenoiSeg** allow to both denoise and segment microscopy images. This section allow you to evaluate both tasks separetly.\n","\n","**Evaluation of the denoising**\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_Denoising_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. 
It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","**Evaluation of the Segmentation**\n","\n","This option will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_Segmentation_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose what to evaluate\n","\n","Evaluate_Denoising = False #@param {type:\"boolean\"}\n","\n","Evaluate_Segmentation = True #@param {type:\"boolean\"}\n","\n","\n","# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_Denoising_folder = \"\" #@param{type:\"string\"}\n","Target_Segmentation_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","#Load the threshold value. 
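\n","# Added note (not part of the original notebook): this threshold is the probability cutoff found by\n","# model.optimize_thresholds() in section 4.2 and stored in the 'threshold' column of training_evaluation.csv;\n","# the manual value entered above is only used when no such CSV or column can be found.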
\n","\n","if os.path.exists(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# ------------- Prepare the model and run predictions ------------\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Source_QC_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_segmentation_\"+base_filename, segmented_images)\n","\n","# ------------- Here we Start assessing the denoising against GT ------------\n","\n","if Evaluate_Denoising:\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = 
x.astype(np.float32, copy=False) - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_Denoising_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_denoised_\"+i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," 
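# Added note for clarity (not from the original notebook): the RSE maps computed above hold |GT - prediction|\n"," # per pixel, so the NRMSE reported here is sqrt(mean(|difference|)) on the normalised images.\n"," 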
# We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n"," Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n"," norm = simple_norm(x, percent = 99)\n","\n"," plt.figure(figsize=(15,15))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(Target_Denoising_folder, Test_FileList[-1]))\n"," plt.imshow(img_GT, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n"," plt.imshow(img_Source, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", \"Predicted_denoised_\"+Test_FileList[-1]))\n"," plt.imshow(img_Prediction, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_denoising.png',bbox_inches='tight',pad_inches=0)\n","#________________________________________________________________________\n","# Here we start testing the differences between GT and predicted masks\n","\n","if Evaluate_Segmentation:\n","\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_segmentation_\"+n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_Segmentation_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n"," from astropy.visualization import simple_norm\n","\n"," # ------------- For display ------------\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def show_QC_results(file = os.listdir(Source_QC_folder)):\n","\n"," plt.figure(figsize=(25,5))\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n"," target_image = io.imread(os.path.join(Target_Segmentation_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/Predicted_segmentation_\"+file, as_gray = True)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_segmentation.png',bbox_inches='tight',pad_inches=0)\n","\n","#Export pdf summary of QC results\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). 
First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is activated and the predictions are finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to process with the trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["import imageio\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDL4Mic, please provide a Threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","#Load the threshold value. 
\n","\n","if os.path.exists(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","print(\"Processing...\")\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(Result_folder+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(Result_folder+\"/\"+\"Predicted_segmentation_\"+base_filename,segmented_images)\n"," \n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Qg31ghpfoNBD"},"source":["## **6.2. Assess predicted output**\r\n","---\r\n","\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"CH7t08UooLba"},"source":["\r\n","# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\r\n","\r\n","\r\n","# This will display a randomly chosen dataset input and predicted output\r\n","random_choice = random.choice(os.listdir(Data_folder))\r\n","x = imread(Data_folder+\"/\"+random_choice)\r\n","\r\n","os.chdir(Result_folder)\r\n","y = imread(Result_folder+\"/\"+\"Predicted_denoised_\"+random_choice)\r\n","z = imread(Result_folder+\"/\"+\"Predicted_segmentation_\"+random_choice)\r\n","\r\n","norm = simple_norm(x, percent = 99)\r\n","\r\n","plt.figure(figsize=(30,15))\r\n","plt.subplot(1, 4, 1)\r\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\r\n","plt.axis('off');\r\n","plt.title(\"Input\")\r\n","\r\n","plt.subplot(1, 4, 2)\r\n","plt.imshow(y, interpolation='nearest', norm=norm, cmap='magma')\r\n","plt.axis('off');\r\n","plt.title(\"Predicted denoised image\")\r\n","\r\n","plt.subplot(1, 4, 3)\r\n","plt.imshow(z, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\r\n","plt.axis('off');\r\n","plt.title(\"Predicted segmentation\")\r\n","\r\n","plt.show()\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using DenoiSeg!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/SplineDist_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/SplineDist_2D_ZeroCostDL4Mic.ipynb index 103fb8b7..397429ca 100644 --- a/Colab_notebooks/Beta notebooks/SplineDist_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/SplineDist_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"SplineDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **SplineDist (2D)**\n","---\n","\n","**SplineDist 2D** is a deep-learning method that can be used to segment objects from bioimages and was first published by [Mandal *et al.* in 2020, on biorXiv](https://www.biorxiv.org/content/10.1101/2020.10.27.357640v1). SplineDist uses a flexible and general representation by modelling objects as planar parametric spline curves.\n","\n"," **This particular notebook enables the segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist or U-net 3D notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**SplineDist: Automated Cell Segmentation with Spline Curves** from Mandal *et al.*, bioRxiv. (https://www.biorxiv.org/content/10.1101/2020.10.27.357640v1)\n","\n","\n","**The Original code** is freely available in GitHub:\n","https://gitlab.ebi.ac.uk/smandal/splinedist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. 
You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For SplineDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","print('TensorFlow version: '+tf.__version__)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. 
\n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install SplineDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["\n","Notebook_version = ['1.11.4']\n","\n","\n","#@markdown ##Install SplineDist and dependencies\n","#%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","!pip install fpdf\n","!pip install PTable # Nice tables\n","\n","\n","\n","!git clone https://gitlab.ebi.ac.uk/smandal/splinedist\n","\n","import os\n","\n","os.chdir(\"/content/splinedist\")\n","\n","!python setup.py install\n","!python splinegenerator.py install\n","\n","import splinegenerator as sg\n","from splinedist.utils import phi_generator, grid_generator, get_contoursize_max, export_imagej_rois\n","from splinedist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from splinedist.matching import matching, matching_dataset\n","from splinedist.models import Config2D, SplineDist2D, SplineDistData2D\n","\n","os.chdir(\"/content\")\n","# ------- Variable specific to Stardist -------\n","from glob import glob\n","from tqdm import tqdm\n","from csbdeep.utils import Path, normalize\n","import numpy as np\n","\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","\n","from tifffile import imsave, imread\n","\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# def get_contoursize_percentile(Y_trn, percentile):\n","# 
# Percentile needs to be between 0 and 100\n","# contoursize = []\n","# for i in range(len(Y_trn)):\n","# mask = Y_trn[i]\n","# obj_list = np.unique(mask)\n","# obj_list = obj_list[1:] \n"," \n","# for j in range(len(obj_list)): \n","# mask_temp = mask.copy() \n","# mask_temp[mask_temp != obj_list[j]] = 0\n","# mask_temp[mask_temp > 0] = 1\n"," \n","# mask_temp = mask_temp.astype(np.uint8) \n","# contours,_ = cv2.findContours(mask_temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)\n","# # areas = [cv2.contourArea(cnt) for cnt in contours] \n","# # max_ind = np.argmax(areas)\n","# # contour = np.squeeze(contours[max_ind])\n","# contour = np.squeeze(contours[0])\n","# contour = np.reshape(contour,(-1,2))\n","# contour = np.append(contour,contour[0].reshape((-1,2)),axis=0)\n","# contoursize = np.append(contoursize,contour.shape[0])\n"," \n","# contoursize_percentile = np.percentile(contoursize, percentile) \n","# return contoursize_percentile\n","\n","\n","def get_contoursize_percentile_from_path(target_path, percentile = 99, show_histogram = False):\n"," # Percentile needs to be between 0 and 100\n"," contoursize = []\n"," Y_list = glob(target_path+\"/*.tif\") \n"," for y in tqdm(Y_list):\n"," Y_im = imread(y)\n"," Y_im = fill_label_holes(Y_im)\n"," obj_list = np.unique(Y_im)\n"," obj_list = obj_list[1:] \n"," \n"," for j in range(len(obj_list)): \n"," mask_temp = Y_im.copy() \n"," mask_temp[mask_temp != obj_list[j]] = 0\n"," mask_temp[mask_temp > 0] = 1\n"," \n"," mask_temp = mask_temp.astype(np.uint8) \n"," contours,_ = cv2.findContours(mask_temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)\n"," perimeter = cv2.arcLength(contours[0],True)\n"," contoursize = np.append(contoursize, perimeter)\n","\n"," contoursize_max = np.amax(contoursize) \n"," contoursize_percentile = np.percentile(contoursize, percentile)\n","\n"," if show_histogram:\n"," # Histogram display\n"," n, bins, patches = plt.hist(x=contoursize, bins='auto', color='#0504aa',\n"," alpha=0.7, rwidth=0.85)\n"," plt.grid(axis='y', alpha=0.75)\n"," plt.xlabel('Contour size')\n"," plt.ylabel('Frequency')\n"," plt.title('Contour size distribution')\n"," plt.text(200, 300, r'$Max='+str(round(contoursize_max,2))+'$')\n"," plt.text(200, 280, r'$'+str(percentile)+'th-per.='+str(round(contoursize_percentile,2))+'$')\n"," maxfreq = n.max();\n"," # Set a clean upper y-axis limit.\n"," plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10);\n","\n"," return contoursize_percentile\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","print('-----------------------------------------')\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**`patch_size`:** Input the size of the patches use to train SplineDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **If your runtime crashes when you start the training, decrease your `patch_size`. Default value: 256** \n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 4**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**, **the default value is also used when set to 0**.\n","\n","**`contoursize`:** Define the size of contour to use (object perimeter). **Default value: the 99th percentile of the perimeter distribution found in the training dataset**, **the default value is also used when set to 0**.\n","\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`grid_parameter`:** increase this number if the object to segment are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`number_of_control_points`:** choose the number of control points. **Default value: 16**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","patch_size = 256#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","contoursize = 0#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","number_of_control_points = 16#@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," batch_size = 4\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," number_of_control_points = 16\n"," initial_learning_rate = 0.0003\n"," print('Estimating optimal contour size...')\n"," contoursize = get_contoursize_percentile_from_path(Training_target, 99, False)\n","\n","else:\n"," if (contoursize == 0):\n"," print('Estimating optimal contour size...')\n"," contoursize = get_contoursize_percentile_from_path(Training_target, 99, True)\n"," \n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open a randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (height, width) =', x.shape)\n","\n","\n"," \n","#Hyperparameter failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print(bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","if patch_size > 2048:\n"," patch_size = 2048\n"," print(bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 16\n","if not patch_size % 16 == 0:\n"," patch_size = ((int(patch_size / 16)-1) * 16)\n"," print(bcolors.WARNING + \" Your chosen patch_size is not divisible by 16; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not run)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Patch size: \", patch_size)\n","print(\"Contour size: \", round(contoursize,2))\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_SplineDist2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here via random rotations, flips, and intensity changes.\n","\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 4 #@param {type:\"slider\", min:1, max:10, step:1}\n","\n","\n","def random_fliprot(img, mask): \n"," assert img.ndim >= mask.ndim\n"," axes = tuple(range(mask.ndim))\n"," perm = tuple(np.random.permutation(axes))\n"," img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(perm) \n"," for ax in axes: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," x, y = random_fliprot(x, y)\n"," x = random_intensity_change(x)\n"," # add some gaussian noise\n"," sig = 0.02*np.random.uniform(0,1)\n"," x = x + sig*np.random.normal(0,1,x.shape)\n"," return x, y\n","\n","\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled\")\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a SplineDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n","\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","# Currently always False for stability\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n"," # number_of_steps = (int(len(X)/batch_size)+1)\n"," number_of_steps = Image_X*Image_Y/(patch_size*patch_size)*(int(len(X)/batch_size)+1)\n"," if (Use_Data_augmentation):\n"," augmentation_factor = Multiply_dataset_by\n"," number_of_steps = number_of_steps * augmentation_factor\n","\n","print('Number of steps: '+str(number_of_steps))\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# # compute the size of the largest contour present in the image-set\n","# Now executed above\n","# if contoursize_max == 0:\n","# # contoursize_max = get_contoursize_max(Y_trn)\n","# contoursize_99p = get_contoursize_percentile(Y_trn, 99) # 99th percentile\n","\n","# # print('Maximum contour size: '+str(contoursize_max))\n","# print('99th percentile contour size: '+str(contoursize_99p))\n","\n","\n","conf = Config2D (\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n"," n_params = 2*number_of_control_points,\n"," contoursize_max = contoursize,\n",")\n","\n","\n","# Here we create a model according to section 5.3.\n","model = SplineDist2D(conf, name=model_name, basedir=model_path)\n","\n","os.chdir(model_path+'/'+model_name)\n","phi_generator(number_of_control_points, conf.contoursize_max)\n","grid_generator(number_of_control_points, conf.train_patch_size, conf.grid)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. 
Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n"]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","\n","\n","\n","# Training the model. \n","\n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# ------ Add the PDF export code here --------\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'SplineDist 2D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," 
version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n"," \n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
number_of_steps{3}
percentage_validation{4}
contoursize_max{5}
grid_parameter{6}
number_of_contour_points{7}
initial_learning_rate{8}
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,round(contoursize,2),grid_parameter,number_of_control_points,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_SplineDist2D.png').shape\n","pdf.image('/content/TrainingDataExample_SplineDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- SplineDist 2D: Mandal et al. \"SplineDist: Automated Cell Segmentation with Spline Curves. bioRxiv 2020.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n","# ------ Add the PDF export code here ------\n","\n","\n","\n","# Displaying the validation data\n","\n","Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]\n"," for x in tqdm(X_val)]\n","\n","def plot_img_label(img, lbl, img_title=\"image\", lbl_title=\"label\", **kwargs):\n"," fig, (ai,al) = plt.subplots(1,2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))\n"," im = ai.imshow(img, cmap='gray', clim=(0,1))\n"," ai.set_title(img_title) \n"," fig.colorbar(im, ax=ai)\n"," al.imshow(lbl, cmap=lbl_cmap)\n"," al.set_title(lbl_title)\n"," plt.tight_layout()\n","\n","plot_img_label(X_val[0],Y_val[0], lbl_title=\"label GT\")\n","plot_img_label(X_val[0],Y_val_pred[0], lbl_title=\"label Pred\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. 
Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n","#Normalize images.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n","\n","model = SplineDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","print('Running predictions...')\n","for i in tqdm(range(lenght_of_Z)):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name)\n"," labels, details = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, details)\n","\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_prediction, test_ground_truth_image, thresh=0.5)\n"," \n","\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," \n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," 
plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'SplineDist 2D'\n","\n","day = datetime.now()\n","datetime = str(day)[0:16]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+datetime\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/16), h = round(exp_size[0]/4))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," #image = header[0]\n"," #PvGT_IoU = header[1]\n"," fp = header[2]\n"," tp = header[3]\n"," fn = header[4]\n"," precision = header[5]\n"," recall = header[6]\n"," acc = header[7]\n"," f1 = header[8]\n"," n_true = header[9]\n"," n_pred = header[10]\n"," mean_true = header[11]\n"," mean_matched = header[12]\n"," panoptic = header[13]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," #image = row[0]\n"," PvGT_IoU = row[1]\n"," fp = row[2]\n"," tp = row[3]\n"," fn = row[4]\n"," precision = row[5]\n"," recall = row[6]\n"," acc = row[7]\n"," f1 = row[8]\n"," n_true = row[9]\n"," n_pred = row[10]\n"," mean_true = row[11]\n"," mean_matched = row[12]\n"," panoptic = row[13]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n"," html = html+cells\n"," html = html+\"\"\"
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In SplineDist the following results can be exported:\n","\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
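\n","\n","As a quick, purely illustrative sketch (the folder and file names below are placeholders for your own paths), the exported CSV files can be opened with pandas once the prediction cell below has been run:\n","\n","```python\n","import pandas as pd\n","\n","# object_count.csv is written without a header: one row per image and its object count\n","counts = pd.read_csv('/content/gdrive/My Drive/Results/object_count.csv', header=None, names=['image', 'object_count'])\n","print(counts)\n","\n","# each *_object_centre.csv stores the Y/X centre coordinates of the detected objects\n","centres = pd.read_csv('/content/gdrive/My Drive/Results/img_1_object_centre.csv')\n","print(centres.head())\n","```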
\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","\n","if Data_type == 1 :\n"," print(\"Single images are now being predicted\")\n"," \n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1] \n"," \n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n"," \n"," model = SplineDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/Predicted_'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in tqdm(range(len(X))):\n","\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," # labels, details = model.predict_instances(img, n_tiles=model._guess_n_tiles(img), show_tile_progress=False)\n"," os.chdir(full_Prediction_model_path)\n"," labels, details = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels)\n"," \n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) 
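\n","    # The 'details' dictionary returned by model.predict_instances() holds the predicted object centres ('points') and contour data ('coord'), which are used below to build the CSV outputs.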
\n"," \n"," Nuclei_centre_coordinate = details['points']\n"," my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n"," my_df2.columns =['Y', 'X']\n"," \n"," my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_object_centre.csv', index=False, header=True)\n","\n"," Nuclei_array = details['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/object_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are being predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = SplineDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in tqdm(range(n_timepoint)):\n"," img_t = timelapse[t]\n"," os.chdir(full_Prediction_model_path)\n"," labels, details = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in details['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using SplineDist 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"SplineDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1gAo0gWChRPnAujXKIPapuw4s2oZ2i733","timestamp":1610725889082},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **SplineDist (2D)**\n","---\n","\n","**SplineDist 2D** is a deep-learning method that can be used to segment objects from bioimages and was first published by [Mandal *et al.* in 2020, on biorXiv](https://www.biorxiv.org/content/10.1101/2020.10.27.357640v1). SplineDist uses a flexible and general representation by modelling objects as planar parametric spline curves.\n","\n"," **This particular notebook enables the segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist or U-net 3D notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**SplineDist: Automated Cell Segmentation with Spline Curves** from Mandal *et al.*, bioRxiv. (https://www.biorxiv.org/content/10.1101/2020.10.27.357640v1)\n","\n","\n","**The Original code** is freely available in GitHub:\n","https://gitlab.ebi.ac.uk/smandal/splinedist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. 
You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For SplineDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","print('TensorFlow version: '+tf.__version__)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. 
\n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install SplineDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["\n","Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install SplineDist and dependencies\n","#%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","!pip install fpdf\n","!pip install PTable # Nice tables\n","\n","\n","!git clone https://github.com/uhlmanngroup/splinedist\n","# !git clone https://gitlab.ebi.ac.uk/smandal/splinedist\n","\n","import os\n","\n","os.chdir(\"/content/splinedist\")\n","\n","!python setup.py install\n","!python splinegenerator.py install\n","\n","import splinegenerator as sg\n","from splinedist.utils import phi_generator, grid_generator, get_contoursize_max, export_imagej_rois\n","from splinedist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from splinedist.matching import matching, matching_dataset\n","from splinedist.models import Config2D, SplineDist2D, SplineDistData2D\n","\n","os.chdir(\"/content\")\n","# ------- Variable specific to Stardist -------\n","from glob import glob\n","from tqdm import tqdm\n","from csbdeep.utils import Path, normalize\n","import numpy as np\n","\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","\n","from tifffile import imsave, imread\n","\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# def 
get_contoursize_percentile(Y_trn, percentile):\n","# # Percentile needs to be between 0 and 100\n","# contoursize = []\n","# for i in range(len(Y_trn)):\n","# mask = Y_trn[i]\n","# obj_list = np.unique(mask)\n","# obj_list = obj_list[1:] \n"," \n","# for j in range(len(obj_list)): \n","# mask_temp = mask.copy() \n","# mask_temp[mask_temp != obj_list[j]] = 0\n","# mask_temp[mask_temp > 0] = 1\n"," \n","# mask_temp = mask_temp.astype(np.uint8) \n","# contours,_ = cv2.findContours(mask_temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)\n","# # areas = [cv2.contourArea(cnt) for cnt in contours] \n","# # max_ind = np.argmax(areas)\n","# # contour = np.squeeze(contours[max_ind])\n","# contour = np.squeeze(contours[0])\n","# contour = np.reshape(contour,(-1,2))\n","# contour = np.append(contour,contour[0].reshape((-1,2)),axis=0)\n","# contoursize = np.append(contoursize,contour.shape[0])\n"," \n","# contoursize_percentile = np.percentile(contoursize, percentile) \n","# return contoursize_percentile\n","\n","\n","def get_contoursize_percentile_from_path(target_path, percentile = 99, show_histogram = False):\n"," # Percentile needs to be between 0 and 100\n"," contoursize = []\n"," Y_list = glob(target_path+\"/*.tif\") \n"," for y in tqdm(Y_list):\n"," Y_im = imread(y)\n"," Y_im = fill_label_holes(Y_im)\n"," obj_list = np.unique(Y_im)\n"," obj_list = obj_list[1:] \n"," \n"," for j in range(len(obj_list)): \n"," mask_temp = Y_im.copy() \n"," mask_temp[mask_temp != obj_list[j]] = 0\n"," mask_temp[mask_temp > 0] = 1\n"," \n"," mask_temp = mask_temp.astype(np.uint8) \n"," contours,_ = cv2.findContours(mask_temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)\n"," perimeter = cv2.arcLength(contours[0],True)\n"," contoursize = np.append(contoursize, perimeter)\n","\n"," contoursize_max = np.amax(contoursize) \n"," contoursize_percentile = np.percentile(contoursize, percentile)\n","\n"," if show_histogram:\n"," # Histogram display\n"," n, bins, patches = plt.hist(x=contoursize, bins='auto', color='#0504aa',\n"," alpha=0.7, rwidth=0.85)\n"," plt.grid(axis='y', alpha=0.75)\n"," plt.xlabel('Contour size')\n"," plt.ylabel('Frequency')\n"," plt.title('Contour size distribution')\n"," plt.text(200, 300, r'$Max='+str(round(contoursize_max,2))+'$')\n"," plt.text(200, 280, r'$'+str(percentile)+'th-per.='+str(round(contoursize_percentile,2))+'$')\n"," maxfreq = n.max();\n"," # Set a clean upper y-axis limit.\n"," plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10);\n","\n"," return contoursize_percentile\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'SplineDist 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," 
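# freeze() lists the installed pip packages, so the report can document the software environment used for training.\n","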
#print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n"," \n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>patch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_steps</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>percentage_validation</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>contoursize_max</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>grid_parameter</td>\n","    <td width = 50%>{6}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_contour_points</td>\n","    <td width = 50%>{7}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>initial_learning_rate</td>\n","    <td width = 50%>{8}</td>\n","  </tr>\n","  </table>\n","
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,round(contoursize,2),grid_parameter,number_of_control_points,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_SplineDist2D.png').shape\n"," pdf.image('/content/TrainingDataExample_SplineDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- SplineDist 2D: Mandal et al. \"SplineDist: Automated Cell Segmentation with Spline Curves. 
bioRxiv 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'SplineDist 2D'\n","\n"," day = datetime.now()\n"," date_time = str(day)[0:16]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/16), h = round(exp_size[0]/4))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," #image = header[0]\n"," #PvGT_IoU = header[1]\n"," fp = header[2]\n"," tp = header[3]\n"," fn = header[4]\n"," precision = header[5]\n"," recall = header[6]\n"," acc = header[7]\n"," f1 = header[8]\n"," n_true = header[9]\n"," n_pred = header[10]\n"," mean_true = header[11]\n"," mean_matched = header[12]\n"," panoptic = header[13]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," #image = row[0]\n"," PvGT_IoU = row[1]\n"," fp = row[2]\n"," tp = row[3]\n"," fn = row[4]\n"," precision = row[5]\n"," recall = row[6]\n"," acc = row[7]\n"," f1 = row[8]\n"," n_true = row[9]\n"," n_pred = row[10]\n"," mean_true = row[11]\n"," mean_matched = row[12]\n"," panoptic = row[13]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n"," html = html+cells\n"," html = html+\"\"\"
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","\n","\n","print('-----------------------------------------')\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. 
Preliminary results can already be observed after 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**`patch_size`:** Input the size of the patches used to train SplineDist 2D (length of a side). The value should be smaller than or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **If your runtime crashes when you start the training, decrease your `patch_size`. Default value: 256** \n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 4**\n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**, **the default value is also used when set to 0**.\n","\n","**`contoursize`:** Define the size of the contour to use (object perimeter). **Default value: the 99th percentile of the perimeter distribution found in the training dataset**, **the default value is also used when set to 0**.\n","\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`grid_parameter`:** Increase this number if the objects to segment are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`number_of_control_points`:** Choose the number of control points. **Default value: 16**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","patch_size = 256#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","contoursize = 0#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","number_of_control_points = 16#@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," batch_size = 4\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," number_of_control_points = 16\n"," initial_learning_rate = 0.0003\n"," print('Estimating optimal contour size...')\n"," contoursize = get_contoursize_percentile_from_path(Training_target, 99, False)\n","\n","else:\n"," if (contoursize == 0):\n"," print('Estimating optimal contour size...')\n"," contoursize = get_contoursize_percentile_from_path(Training_target, 99, True)\n"," \n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print(bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","if patch_size > 2048:\n"," patch_size = 2048\n"," print(bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 16\n","if not patch_size % 16 == 0:\n"," patch_size = ((int(patch_size / 16)-1) * 16)\n"," print(bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Patch size: \", patch_size)\n","print(\"Contour size: \", round(contoursize,2))\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_SplineDist2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here via random rotations, flips, and intensity changes.\n","\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 1 #@param {type:\"slider\", min:1, max:10, step:1}\n","\n","\n","def random_fliprot(img, mask): \n"," assert img.ndim >= mask.ndim\n"," axes = tuple(range(mask.ndim))\n"," perm = tuple(np.random.permutation(axes))\n"," img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(perm) \n"," for ax in axes: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," x, y = random_fliprot(x, y)\n"," x = random_intensity_change(x)\n"," # add some gaussian noise\n"," sig = 0.02*np.random.uniform(0,1)\n"," x = x + sig*np.random.normal(0,1,x.shape)\n"," return x, y\n","\n","\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled\")\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a SplineDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
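As an illustration only (this sketch is not one of the notebook's own cells), the learning-rate recovery described above can be reproduced with pandas, assuming the ZeroCostDL4Mic layout in which `Quality Control/training_evaluation.csv` stores `loss`, `val_loss` and `learning rate` columns:

```python
import os
import pandas as pd

# Hypothetical model folder, used purely for illustration; the notebook builds this
# path from the user-supplied pretrained_model_path parameter.
pretrained_model_path = "/content/gdrive/My Drive/my_pretrained_model"
csv_path = os.path.join(pretrained_model_path, "Quality Control", "training_evaluation.csv")

history = pd.read_csv(csv_path)
last_lr = history["learning rate"].iloc[-1]                            # Weights_choice == "last"
best_lr = history.loc[history["val_loss"].idxmin(), "learning rate"]   # Weights_choice == "best"
print("last:", last_lr, "best:", best_lr)
```

Models trained outside ZeroCostDL4Mic typically lack this CSV, in which case the notebook falls back to the `initial_learning_rate` set in section 3.1.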
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n","\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
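# Illustrative aside (not part of the original cell): the same 90/10 split performed just below
# could also be written with scikit-learn, assuming scikit-learn is available in the runtime:
#
#   from sklearn.model_selection import train_test_split
#   X_trn, X_val, Y_trn, Y_val = train_test_split(X, Y, test_size=percentage, random_state=42)
#
# The notebook's own approach below permutes indices with a fixed seed and guarantees
# at least one validation image via max(1, ...).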
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","# Currently always False for stability\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n"," # number_of_steps = (int(len(X)/batch_size)+1)\n"," number_of_steps = Image_X*Image_Y/(patch_size*patch_size)*(int(len(X)/batch_size)+1)\n"," if (Use_Data_augmentation):\n"," augmentation_factor = Multiply_dataset_by\n"," number_of_steps = number_of_steps * augmentation_factor\n","\n","print('Number of steps: '+str(number_of_steps))\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# # compute the size of the largest contour present in the image-set\n","# Now executed above\n","# if contoursize_max == 0:\n","# # contoursize_max = get_contoursize_max(Y_trn)\n","# contoursize_99p = get_contoursize_percentile(Y_trn, 99) # 99th percentile\n","\n","# # print('Maximum contour size: '+str(contoursize_max))\n","# print('99th percentile contour size: '+str(contoursize_99p))\n","\n","\n","conf = Config2D (\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n"," n_params = 2*number_of_control_points,\n"," contoursize_max = contoursize,\n",")\n","\n","\n","# Here we create a model according to section 5.3.\n","model = SplineDist2D(conf, name=model_name, basedir=model_path)\n","\n","os.chdir(model_path+'/'+model_name)\n","phi_generator(number_of_control_points, conf.contoursize_max)\n","grid_generator(number_of_control_points, conf.train_patch_size, conf.grid)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","#Export pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = 
Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n"]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","\n","\n","\n","# Training the model. \n","\n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","# Displaying the validation data\n","\n","Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]\n"," for x in tqdm(X_val)]\n","\n","def plot_img_label(img, lbl, img_title=\"image\", lbl_title=\"label\", **kwargs):\n"," fig, (ai,al) = plt.subplots(1,2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))\n"," im = ai.imshow(img, cmap='gray', clim=(0,1))\n"," ai.set_title(img_title) \n"," fig.colorbar(im, ax=ai)\n"," al.imshow(lbl, cmap=lbl_cmap)\n"," al.set_title(lbl_title)\n"," plt.tight_layout()\n","\n","plot_img_label(X_val[0],Y_val[0], lbl_title=\"label GT\")\n","plot_img_label(X_val[0],Y_val_pred[0], lbl_title=\"label Pred\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. 
Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n","#Normalize images.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n","\n","model = SplineDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","print('Running predictions...')\n","for i in tqdm(range(lenght_of_Z)):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name)\n"," labels, details = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, details)\n","\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_prediction, test_ground_truth_image, thresh=0.5)\n"," \n","\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," \n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," 
plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In SplineDist the following results can be exported:\n","\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
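For orientation, the core of the single-image prediction that the next cell performs can be sketched as follows. This is a simplified outline with hypothetical folder names and an assumed import path, not a replacement for the cell below (which additionally handles stacks, tracking files and the CSV exports listed above):

```python
import os
from glob import glob
from tifffile import imread, imsave
from csbdeep.utils import normalize
# Import path assumed from the SplineDist package; the notebook loads SplineDist2D in its setup cell.
from splinedist.models import SplineDist2D

# Hypothetical paths; replace with your Prediction_model_folder, Data_folder and Results_folder.
model = SplineDist2D(None, name="my_model", basedir="/content/gdrive/My Drive/models")

for path in sorted(glob("/content/gdrive/My Drive/to_predict/*.tif")):
    img = normalize(imread(path), 1, 99.8)            # percentile normalisation, as during training
    labels, details = model.predict_instances(img)    # labels: instance mask; details: contours and centres
    imsave("/content/gdrive/My Drive/results/Predicted_" + os.path.basename(path), labels)
```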
\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Stacks #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","\n","if Data_type == 1 :\n"," print(\"Single images are now being predicted\")\n"," \n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1] \n"," \n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n"," \n"," model = SplineDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/Predicted_'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in tqdm(range(len(X))):\n","\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," # labels, details = model.predict_instances(img, n_tiles=model._guess_n_tiles(img), show_tile_progress=False)\n"," os.chdir(full_Prediction_model_path)\n"," labels, details = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels)\n"," \n"," if Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," 
\n"," Nuclei_centre_coordinate = details['points']\n"," my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n"," my_df2.columns =['Y', 'X']\n"," \n"," my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_object_centre.csv', index=False, header=True)\n","\n"," Nuclei_array = details['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/object_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are being predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = SplineDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in tqdm(range(n_timepoint)):\n"," img_t = timelapse[t]\n"," os.chdir(full_Prediction_model_path)\n"," labels, details = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in details['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave('Predicted_'+names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", Tracking_stack_8_rot_flip)\n","\n"," \n","print('---------------------')\n","print(\"Predictions completed.\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using SplineDist 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb index 34d22693..aafa35c8 100644 --- a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1ocbvEKFFBCLymjK-3IQUQ7Lkb7hoQQym","timestamp":1601991456561},{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["\n","#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install CARE and dependencies\n","\n","#Libraries contains information of certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_CARE2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_-CEUqlS8o3M"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"qe9zvEJ9qOH2"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"zmtlu9YU266X","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if 
Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4kb3xSZMRzxU"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"mlN-VNOgR-nr","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in 
csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"LKYRNhA5Qnis","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). 
Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"biXiR017C4UU","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'CARE 2D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," 
version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>number_of_patches</td><td>{2}</td></tr>
<tr><td>batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n","pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","if Use_Data_augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","loss_displayed = False"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","loss_displayed = True\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the 
csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," 
writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'CARE 2D'\n","#model_name = os.path.basename(full_QC_model_path)\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control 
Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw"},"source":["\n","#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). 
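As an aside to the Quality Control figures assembled above, the short sketch below shows how the three reported metrics (mSSIM, NRMSE and PSNR) can be obtained for a single ground-truth/prediction pair with scikit-image. The file paths are placeholders, and the NRMSE call uses scikit-image's default normalisation rather than the notebook's own normalisation routine, so treat it as illustrative only.

from skimage import io, img_as_float32
from skimage.metrics import structural_similarity, peak_signal_noise_ratio, normalized_root_mse

# Placeholder paths: one matched ground-truth / prediction pair from the QC dataset
img_GT = img_as_float32(io.imread('Quality Control/Target/img_1.tif'))
img_prediction = img_as_float32(io.imread('Quality Control/Prediction/img_1.tif'))

data_range = float(img_GT.max() - img_GT.min())

# mSSIM plus the per-pixel SSIM map (the map is what the QC panels display)
index_SSIM, img_SSIM_map = structural_similarity(img_GT, img_prediction, data_range=data_range, full=True)
NRMSE = normalized_root_mse(img_GT, img_prediction)  # Euclidean normalisation by default
PSNR = peak_signal_noise_ratio(img_GT, img_prediction, data_range=data_range)

print('mSSIM: %.3f, NRMSE: %.3f, PSNR: %.3f' % (index_SSIM, NRMSE, PSNR))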
The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D dataset. If you are interested in restoring 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) 
you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. 
Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["\n","#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install CARE and dependencies\n","\n","#Libraries contains information of certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
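# --- Optional sanity check (not part of the original notebook) ----------------
# Section 0 requires a paired dataset: .tif files with identical names in the
# source and target folders. The lines below verify that pairing before training.
# Both folder paths are placeholders; the *_check names are used so that the real
# Training_source / Training_target parameters set in section 3.1 are not touched.
import os

Training_source_check = '/content/gdrive/MyDrive/Training - Low SNR images'   # placeholder
Training_target_check = '/content/gdrive/MyDrive/Training - high SNR images'  # placeholder

source_files = {f for f in os.listdir(Training_source_check) if f.lower().endswith('.tif')}
target_files = {f for f in os.listdir(Training_target_check) if f.lower().endswith('.tif')}

unpaired = sorted(source_files ^ target_files)  # symmetric difference = files without a partner
if unpaired:
    print('Unpaired .tif files found:', unpaired)
else:
    print(str(len(source_files)) + ' matching source/target pairs found.')
# -------------------------------------------------------------------------------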
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n","\n","#Create a pdf document with training summary\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). 
The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
number_of_patches{2}
batch_size{3}
number_of_steps{4}
percentage_validation{5}
initial_learning_rate{6}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","#Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n","\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 100** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. 
Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 80#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # in pixels\n","number_of_patches = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 400#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_CARE2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
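Before the augmentation options, a quick worked example of how the section 3.1 parameters translate into training steps; the number of training images below is an assumed value chosen only for illustration, and the notebook itself derives the same quantity from the patch array after the validation split.

dataset_size = 40          # assumed number of image pairs in Training_source
number_of_patches = 100    # default value from section 3.1
batch_size = 16            # default value from section 3.1

total_patches = dataset_size * number_of_patches    # 4000 training patches
number_of_steps = total_patches // batch_size + 1   # 251 steps, so each patch is seen about once per epoch
print(number_of_steps)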
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if 
Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in 
csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. 
Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","loss_displayed = False"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","loss_displayed = True\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the 
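(Editorial note, not part of the notebook code.) The scale factor used in norm_minmse above, scale = cov(x, gt) / var(x), is the ordinary least-squares solution of min_s ||gt - s*x||^2 after mean subtraction, i.e. s = sum(x*gt) / sum(x*x); this is why the docstring describes the pair as affinely scaled so that the MSE is minimized. A quick numerical check on assumed toy data:

import numpy as np

rng = np.random.default_rng(1)
gt = rng.random(10000)
x = 3.0 * gt + 0.1 * rng.standard_normal(10000)    # x is a scaled, noisy copy of gt

gt0, x0 = gt - gt.mean(), x - x.mean()
scale_cov = np.cov(x0, gt0)[0, 1] / np.var(x0)     # as computed in norm_minmse
scale_ls = np.sum(x0 * gt0) / np.sum(x0 ** 2)      # closed-form least-squares scale
# Both values are ~1/3; they differ only by a factor N/(N-1) because np.cov uses N-1 while np.var uses N.
print(scale_cov, scale_ls)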
csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," 
writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bShxBHY4vFFd"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"6b2t6SLQvIBO"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb index feba000f..4b352405 100644 --- a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"BDhmUgqCStlm"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"3u2mXn3XsWzd"},"source":["Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install CARE and dependencies\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. 
Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_CARE3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"htqjkJWt5J_8"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"cellView":"form","id":"8vPkzEBNamE4"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = 
initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","id":"WMJnGJpCMa4y"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"cellView":"form","id":"j_Qm5JBmlvJg"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'CARE 3D'\n","#model_name = 'little_CARE_test'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," 
version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>patch_height</td><td>{2}</td></tr>
<tr><td>number_of_patches</td><td>{3}</td></tr>
<tr><td>batch_size</td><td>{4}</td></tr>
<tr><td>number_of_steps</td><td>{5}</td></tr>
<tr><td>percentage_validation</td><td>{6}</td></tr>
<tr><td>initial_learning_rate</td><td>{7}</td></tr>
\n","\"\"\".format(number_of_epochs,patch_size,patch_height,number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_CARE3D.png').shape\n","pdf.image('/content/TrainingDataExample_CARE3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w8Q_uYGgiico"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"zazOZ3wDx0zQ"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"nAs4Wni7VYbq"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. 
\n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR 
between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'CARE 3D'\n","#model_name = os.path.basename(QC_model_folder)\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td></tr>
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"Am2JSmpC0frj"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. \n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install CARE and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install CARE and dependencies\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 3D'\n"," #model_name = 'little_CARE_test'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>patch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>patch_height</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_patches</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_steps</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>percentage_validation</td>\n","    <td width = 50%>{6}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>initial_learning_rate</td>\n","    <td width = 50%>{7}</td>\n","  </tr>\n","  </table>
\n"," \"\"\".format(number_of_epochs,patch_size,patch_height,number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE3D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 3D'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. 
**Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 80#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_CARE3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. 
This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"htqjkJWt5J_8"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
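\n","\n","Purely as an illustration, the learning rates that the next cell looks for can also be inspected by hand from the `training_evaluation.csv` stored in the Quality Control folder of a ZeroCostDL4Mic model (the model folder used below is a hypothetical example):\n","\n","```python\n","import pandas as pd\n","\n","# hypothetical path to a model folder previously trained with this notebook\n","df = pd.read_csv('my_model/Quality Control/training_evaluation.csv')\n","\n","print('last learning rate:', df['learning rate'].iloc[-1])\n","print('learning rate at lowest val_loss:', df.loc[df['val_loss'].idxmin(), 'learning rate'])\n","```\n","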
"]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = 
initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","%memit \n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). 
Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"zazOZ3wDx0zQ"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
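For images normalised to a data range of 1 (as in the cell below), this reduces to PSNR = -10 * log10(MSE). A minimal numpy sketch of that computation, for illustration only (the cell below uses the equivalent scikit-image function `peak_signal_noise_ratio`):\n","\n","```python\n","import numpy as np\n","\n","def psnr_from_mse(gt, pred, data_range=1.0):\n","    # gt and pred are assumed to be normalised to the same intensity range\n","    mse = np.mean((gt - pred) ** 2)\n","    return 10 * np.log10(data_range ** 2 / mse)\n","```\n","\n","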
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. 
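\n","# With config=None, CSBDeep reloads the saved configuration and trained weights from\n","# basedir/name (the QC model folder chosen above) instead of building a new network.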
\n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR 
between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
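For orientation, loading a previously trained model and restoring a single stack boils down to the following minimal sketch (the paths and model name are hypothetical placeholders; the prediction cell below additionally loops over the whole Data_folder and adds an interactive viewer):

```python
# Minimal sketch only -- hypothetical paths and model name; mirrors the prediction cell below.
from csbdeep.models import CARE
from csbdeep.io import save_tiff_imagej_compatible
from tifffile import imread

# Load a model previously saved by this notebook (stored in basedir/name)
model = CARE(config=None, name="my_model", basedir="/content/gdrive/My Drive/models")

stack = imread("/content/gdrive/My Drive/unseen/stack_01.tif")   # expected axes: Z, Y, X
restored = model.predict(stack, axes='ZYX', n_tiles=(1, 2, 2))   # increase the tile numbers if you hit an OOM error
save_tiff_imagej_compatible("/content/gdrive/My Drive/results/stack_01.tif", restored, axes='ZYX')
```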
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"Am2JSmpC0frj"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. 
\n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/ChangeLog.txt b/Colab_notebooks/ChangeLog.txt index 9608d1dc..1bbc0528 100644 --- a/Colab_notebooks/ChangeLog.txt +++ b/Colab_notebooks/ChangeLog.txt @@ -8,6 +8,18 @@ Latest releases available here: https://github.com/HenriquesLab/ZeroCostDL4Mic/releases +————————————————————————————————————————————————————————— +ZeroCostDL4Mic v1.12 + +Major changes: + +- PDF export of training session is now done before the training starts, so records are kept in case training fails to finalise. +- StarDist 2D is now compatible with RGB image input +- Beta notebooks: new notebooks available: SplineDist and 3D-RCAN. SplineDist also includes compatibility with RGB input. + ++ general minor notebook optimisations. 
+ + ————————————————————————————————————————————————————————— ZeroCostDL4Mic v1.11 diff --git a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb index 1d95e531..111204b9 100644 --- a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1V02Qd1PuJ2RECkl24136fhLizyX2KkrL","timestamp":1602673365922},{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","cellView":"form"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. 
Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images) **\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. 
Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","\n","#@markdown ##Install CycleGAN and dependencies\n","\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use Cyclegan we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FX6uxFvI-CsQ"},"source":["## **3.2. Data augmentation**\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CwMaFU1T-GtN"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. 
\n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","metadata":{"id":"kLtHIATT-0un","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. 
to prepare the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session."]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","cellView":"form"},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'cycleGAN'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = 
all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by default'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n","pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"PhcOwcgH3JAD"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E4Yp7ogh3NGD"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD"},"source":[" CycleGAN save model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," os.chdir(\"/content\")\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch 
$checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = 
(img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can 
also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Threshold plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'cycleGAN'\n","\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n","pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n","if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n","if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. 
To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that checkpoint exist, if not the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder 
\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw"},"source":["\n","#**Thank you for using CycleGAN!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611059046709},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. 
If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jqvkQQkcuMmM"},"source":["# **License**\r\n","\r\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"vCihhAzluRvI"},"source":["#@markdown ##Double click to see the license information\r\n","\r\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\r\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\r\n","\r\n","\r\n","\r\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\r\n","\r\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\r\n","#All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without\r\n","#modification, are permitted provided that the following conditions are met:\r\n","\r\n","#* Redistributions of source code must retain the above copyright notice, this\r\n","# list of conditions and the following disclaimer.\r\n","\r\n","#* Redistributions in binary form must reproduce the above copyright notice,\r\n","# this list of conditions and the following disclaimer in the documentation\r\n","# and/or other materials provided with the distribution.\r\n","\r\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\r\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\r\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\r\n","#DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\r\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\r\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\r\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\r\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\r\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r\n","\r\n","\r\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\r\n","#BSD License\r\n","\r\n","#For pix2pix software\r\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\r\n","#All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without\r\n","#modification, are permitted provided that the following conditions are met:\r\n","\r\n","#* Redistributions of source code must retain the above copyright notice, this\r\n","# list of conditions and the following disclaimer.\r\n","\r\n","#* Redistributions in binary form must reproduce the above copyright notice,\r\n","# this list of conditions and the following disclaimer in the documentation\r\n","# and/or other materials provided with the distribution.\r\n","\r\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\r\n","#BSD License\r\n","\r\n","#For dcgan.torch software\r\n","\r\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\r\n","\r\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\r\n","\r\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\r\n","\r\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\r\n","\r\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["#**0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. 
These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. 
\n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","\n","#@markdown ##Install CycleGAN and dependencies\n","\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'cycleGAN'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
    <tr><th>Parameter</th><th>Value</th></tr>
    <tr><td>number_of_epochs</td><td>{0}</td></tr>
    <tr><td>patch_size</td><td>{1}</td></tr>
    <tr><td>batch_size</td><td>{2}</td></tr>
    <tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n"," pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'cycleGAN'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}
{0}{1}{2}
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use Cyclegan we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = 
imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. \n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
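As an aside to the description above: a model folder written by the pytorch-CycleGAN-and-pix2pix implementation stores its most recent weights as `latest_net_*.pth` files, and these are what the cell in this section looks for before enabling the pre-trained option. The snippet below is only a minimal sketch of that check, not the notebook cell itself; the folder path is a placeholder, and the discriminator checkpoints are listed because a full `--continue_train` resume loads them as well.

```python
import os

# Minimal sketch, assumed placeholder path (not the notebook cell itself).
pretrained_model_path = "/content/gdrive/My Drive/models/my_cyclegan_model"

# "latest" checkpoints written by pytorch-CycleGAN-and-pix2pix:
generators = ["latest_net_G_A.pth", "latest_net_G_B.pth"]
discriminators = ["latest_net_D_A.pth", "latest_net_D_B.pth"]

missing_G = [f for f in generators
             if not os.path.exists(os.path.join(pretrained_model_path, f))]
missing_D = [f for f in discriminators
             if not os.path.exists(os.path.join(pretrained_model_path, f))]

if missing_G:
    print("Generator checkpoint(s) missing; this folder cannot be used as a pre-trained model:", missing_G)
elif missing_D:
    print("Generators found, but discriminator checkpoint(s) are missing:", missing_D)
else:
    print("All 'latest' checkpoints found; training can be continued from this model.")
```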
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. to prepare the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Save training summary as pdf\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"1Wext8woxt_F"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1CFbjvTpx5C3"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"q8tCfAadx96X"},"source":[" CycleGAN saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\r\n","\r\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. Metrics used include:\r\n","\r\n","**1. The SSIM (structural similarity) map** \r\n","\r\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \r\n","\r\n","**mSSIM** is the mean SSIM value calculated across the entire image.\r\n","\r\n","**The output below shows the SSIM maps with the mSSIM**\r\n","\r\n","**2. The RSE (Root Squared Error) map** \r\n","\r\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\r\n","\r\n","\r\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\r\n","\r\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
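To make these metrics concrete, here is a small self-contained sketch using synthetic images. It is not the notebook's own QC code (that runs in the cell of section 5.2 below); it only illustrates how the mSSIM, the SSIM map, the RSE-based NRMSE used in this notebook, and the PSNR can be obtained with scikit-image.

```python
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# Synthetic stand-ins for a normalised ground-truth image and a prediction.
rng = np.random.default_rng(0)
ground_truth = rng.random((256, 256))
prediction = np.clip(ground_truth + 0.05 * rng.standard_normal((256, 256)), 0, 1)

# mSSIM and the per-pixel SSIM map, using the Gaussian-weighted window
# (sigma = 1.5) described above.
mssim, ssim_map = structural_similarity(
    ground_truth, prediction, data_range=1.0, full=True,
    gaussian_weights=True, sigma=1.5, use_sample_covariance=False)

# Root-squared-error map, the NRMSE variant computed later in this notebook's
# QC cell, and the PSNR.
rse_map = np.sqrt(np.square(ground_truth - prediction))
nrmse = np.sqrt(np.mean(rse_map))
psnr_value = peak_signal_noise_ratio(ground_truth, prediction, data_range=1.0)

print("mSSIM: {:.3f}, NRMSE: {:.3f}, PSNR: {:.2f} dB".format(mssim, nrmse, psnr_value))
```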
The higher the score the better the agreement.\r\n","\r\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"q2T4t8NNyDZ6"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\r\n","\r\n","Source_QC_folder = \"\" #@param{type:\"string\"}\r\n","Target_QC_folder = \"\" #@param{type:\"string\"}\r\n","\r\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\r\n","\r\n","# average function\r\n","def Average(lst): \r\n"," return sum(lst) / len(lst) \r\n","\r\n","\r\n","# Create a quality control folder\r\n","\r\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\r\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\r\n","\r\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\r\n","\r\n","# List images in Source_QC_folder\r\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \r\n","random_choice = random.choice(os.listdir(Source_QC_folder))\r\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n","#Find image XY dimension\r\n","Image_Y = x.shape[0]\r\n","Image_X = x.shape[1]\r\n","\r\n","Image_min_dim = min(Image_Y, Image_X)\r\n","\r\n","\r\n","# Here we need to move the data to be analysed so that cycleGAN can find them\r\n","\r\n","Saving_path_QC= \"/content/\"+QC_model_name\r\n","\r\n","if os.path.exists(Saving_path_QC):\r\n"," shutil.rmtree(Saving_path_QC)\r\n","os.makedirs(Saving_path_QC)\r\n","\r\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\r\n","\r\n","if os.path.exists(Saving_path_QC_folder):\r\n"," shutil.rmtree(Saving_path_QC_folder)\r\n","os.makedirs(Saving_path_QC_folder)\r\n","\r\n","\r\n","#Here we copy and rename the all the checkpoint to be analysed\r\n","\r\n","for f in os.listdir(full_QC_model_path):\r\n"," shortname = f[:-6]\r\n"," shortname = shortname + \".pth\"\r\n"," if f.endswith(\"net_G_A.pth\"):\r\n"," shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\r\n","\r\n","\r\n","for files in os.listdir(Source_QC_folder):\r\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\r\n"," \r\n","\r\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \r\n","random_choice = random.choice(os.listdir(Source_QC_folder))\r\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n","#Find image XY dimension\r\n","Image_Y = x.shape[0]\r\n","Image_X = x.shape[1]\r\n","\r\n","Image_min_dim = int(min(Image_Y, Image_X))\r\n","\r\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\r\n","\r\n","print(Nb_Checkpoint)\r\n","\r\n","\r\n","\r\n","## Initiate list\r\n","\r\n","Checkpoint_list = []\r\n","Average_ssim_score_list = []\r\n","\r\n","\r\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\r\n"," checkpoints = j*5\r\n","\r\n"," if checkpoints == Nb_Checkpoint*5:\r\n"," checkpoints = \"latest\"\r\n","\r\n","\r\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\r\n","\r\n"," Checkpoint_list.append(checkpoints)\r\n","\r\n","\r\n"," # Create a quality control/Prediction Folder\r\n","\r\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\r\n","\r\n"," if os.path.exists(QC_prediction_results):\r\n"," shutil.rmtree(QC_prediction_results)\r\n","\r\n"," os.makedirs(QC_prediction_results)\r\n","\r\n","\r\n","\r\n","#---------------------------- Predictions are 
performed here ----------------------\r\n","\r\n"," os.chdir(\"/content\")\r\n","\r\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\r\n","\r\n","#-----------------------------------------------------------------------------------\r\n","\r\n","#Here we need to move the data again and remove all the unnecessary folders\r\n","\r\n"," Checkpoint_name = \"test_\"+str(checkpoints)\r\n","\r\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\r\n","\r\n"," QC_results_images_files = os.listdir(QC_results_images)\r\n","\r\n"," for f in QC_results_images_files: \r\n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\r\n","\r\n"," os.chdir(\"/content\") \r\n","\r\n"," #Here we clean up the extra files\r\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\r\n","\r\n","\r\n","#-------------------------------- QC for RGB ------------------------------------\r\n"," if Image_type == \"RGB\":\r\n","# List images in Source_QC_folder\r\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \r\n"," random_choice = random.choice(os.listdir(Source_QC_folder))\r\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n"," def ssim(img1, img2):\r\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\r\n","\r\n","# Open and create the csv file that will contain all the QC metrics\r\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\r\n"," writer = csv.writer(file)\r\n","\r\n"," # Write the header in the csv file\r\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\r\n"," \r\n"," \r\n"," # Initiate list\r\n"," ssim_score_list = [] \r\n","\r\n","\r\n"," # Let's loop through the provided dataset in the QC folders\r\n","\r\n","\r\n"," for i in os.listdir(Source_QC_folder):\r\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\r\n"," print('Running QC on: '+i)\r\n","\r\n"," shortname_no_PNG = i[:-4]\r\n"," \r\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\r\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\r\n","\r\n"," # -------------------------------- Source test data --------------------------------\r\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\r\n"," \r\n"," \r\n"," # -------------------------------- Prediction --------------------------------\r\n"," \r\n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\r\n"," \r\n"," #--------------------------- Here we normalise using histograms matching--------------------------------\r\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\r\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\r\n"," \r\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\r\n","\r\n"," # Calculate the SSIM maps\r\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\r\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\r\n","\r\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\r\n","\r\n"," #Save ssim_maps\r\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\r\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\r\n"," \r\n"," \r\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\r\n","\r\n"," #Here we calculate the ssim average for each image in each checkpoints\r\n","\r\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\r\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\r\n","\r\n","\r\n","\r\n","\r\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\r\n","\r\n"," if Image_type == \"Grayscale\":\r\n"," def ssim(img1, img2):\r\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\r\n","\r\n","\r\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\r\n","\r\n","\r\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\r\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\r\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\r\n","\r\n","\r\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\r\n"," \r\n"," if dtype is not None:\r\n"," x = x.astype(dtype,copy=False)\r\n"," mi = dtype(mi) if np.isscalar(mi) else 
mi.astype(dtype,copy=False)\r\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\r\n"," eps = dtype(eps)\r\n","\r\n"," try:\r\n"," import numexpr\r\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\r\n"," except ImportError:\r\n"," x = (x - mi) / ( ma - mi + eps )\r\n","\r\n"," if clip:\r\n"," x = np.clip(x,0,1)\r\n","\r\n"," return x\r\n","\r\n"," def norm_minmse(gt, x, normalize_gt=True):\r\n"," \r\n"," if normalize_gt:\r\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\r\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\r\n"," #x = x - np.mean(x)\r\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\r\n"," #gt = gt - np.mean(gt)\r\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\r\n"," return gt, scale * x\r\n","\r\n","# Open and create the csv file that will contain all the QC metrics\r\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\r\n"," writer = csv.writer(file)\r\n","\r\n"," # Write the header in the csv file\r\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \r\n","\r\n"," \r\n"," \r\n"," # Let's loop through the provided dataset in the QC folders\r\n","\r\n","\r\n"," for i in os.listdir(Source_QC_folder):\r\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\r\n"," print('Running QC on: '+i)\r\n","\r\n"," ssim_score_list = []\r\n"," shortname_no_PNG = i[:-4]\r\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\r\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\r\n"," \r\n"," test_GT = test_GT_raw[:,:,2]\r\n","\r\n"," # -------------------------------- Source test data --------------------------------\r\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\r\n"," \r\n"," test_source = test_source_raw[:,:,2]\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\r\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\r\n","\r\n"," # -------------------------------- Prediction --------------------------------\r\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\r\n"," \r\n"," test_prediction = test_prediction_raw[:,:,2]\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\r\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \r\n","\r\n","\r\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\r\n","\r\n"," # Calculate the SSIM maps\r\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\r\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\r\n","\r\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\r\n","\r\n"," #Save ssim_maps\r\n"," \r\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\r\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\r\n"," \r\n"," # Calculate the Root Squared Error (RSE) maps\r\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\r\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\r\n","\r\n"," # Save SE maps\r\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\r\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\r\n","\r\n","\r\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\r\n","\r\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\r\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\r\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\r\n"," \r\n"," # We can also measure the peak signal to noise ratio between the images\r\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\r\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\r\n","\r\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\r\n","\r\n"," #Here we calculate the ssim average for each image in each checkpoints\r\n","\r\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\r\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\r\n","\r\n","\r\n","# All data is now processed saved\r\n"," \r\n","\r\n","# -------------------------------- Display --------------------------------\r\n","\r\n","# Display the IoV vs Threshold plot\r\n","plt.figure(figsize=(20,5))\r\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\r\n","plt.title('Checkpoints vs. SSIM')\r\n","plt.ylabel('SSIM')\r\n","plt.xlabel('Checkpoints')\r\n","plt.legend()\r\n","plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\r\n","plt.show()\r\n","\r\n","\r\n","\r\n","# -------------------------------- Display RGB --------------------------------\r\n","\r\n","from ipywidgets import interact\r\n","import ipywidgets as widgets\r\n","\r\n","\r\n","if Image_type == \"RGB\":\r\n"," random_choice_shortname_no_PNG = shortname_no_PNG\r\n","\r\n"," @interact\r\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\r\n","\r\n"," random_choice_shortname_no_PNG = file[:-4]\r\n","\r\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\r\n"," df2 = df1.set_index(\"image #\", drop = False)\r\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\r\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\r\n","\r\n","#Setting up colours\r\n"," \r\n"," cmap = None\r\n","\r\n"," plt.figure(figsize=(10,10))\r\n","\r\n","# Target (Ground-truth)\r\n"," plt.subplot(3,3,1)\r\n"," plt.axis('off')\r\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\r\n"," plt.imshow(img_GT, cmap = cmap)\r\n"," plt.title('Target',fontsize=15)\r\n","\r\n","# Source\r\n"," plt.subplot(3,3,2)\r\n"," plt.axis('off')\r\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\r\n"," plt.imshow(img_Source, cmap = cmap)\r\n"," plt.title('Source',fontsize=15)\r\n","\r\n","#Prediction\r\n"," plt.subplot(3,3,3)\r\n"," plt.axis('off')\r\n","\r\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\r\n","\r\n"," plt.imshow(img_Prediction, cmap = cmap)\r\n"," plt.title('Prediction',fontsize=15)\r\n","\r\n","\r\n","#SSIM between GT and Source\r\n"," plt.subplot(3,3,5)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\r\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. Source',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\r\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#SSIM between GT and Prediction\r\n"," plt.subplot(3,3,6)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","\r\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\r\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. 
Prediction',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\r\n"," plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\r\n","\r\n","# -------------------------------- Display Grayscale --------------------------------\r\n","\r\n","if Image_type == \"Grayscale\":\r\n"," random_choice_shortname_no_PNG = shortname_no_PNG\r\n","\r\n"," @interact\r\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\r\n","\r\n"," random_choice_shortname_no_PNG = file[:-4]\r\n","\r\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\r\n"," df2 = df1.set_index(\"image #\", drop = False)\r\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\r\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\r\n","\r\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\r\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\r\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\r\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\r\n"," \r\n","\r\n"," plt.figure(figsize=(15,15))\r\n","\r\n"," cmap = None\r\n"," \r\n"," # Target (Ground-truth)\r\n"," plt.subplot(3,3,1)\r\n"," plt.axis('off')\r\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\r\n","\r\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\r\n"," plt.title('Target',fontsize=15)\r\n","\r\n","# Source\r\n"," plt.subplot(3,3,2)\r\n"," plt.axis('off')\r\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\r\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\r\n"," plt.title('Source',fontsize=15)\r\n","\r\n","#Prediction\r\n"," plt.subplot(3,3,3)\r\n"," plt.axis('off')\r\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\r\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\r\n"," plt.title('Prediction',fontsize=15)\r\n","\r\n","#Setting up colours\r\n"," cmap = plt.cm.CMRmap\r\n","\r\n","#SSIM between GT and Source\r\n"," plt.subplot(3,3,5)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\r\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\r\n","\r\n"," \r\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. 
Source',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\r\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#SSIM between GT and Prediction\r\n"," plt.subplot(3,3,6)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n"," \r\n"," \r\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\r\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\r\n","\r\n"," \r\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. Prediction',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\r\n","\r\n","#Root Squared Error between GT and Source\r\n"," plt.subplot(3,3,8)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\r\n"," \r\n","\r\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\r\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\r\n"," plt.title('Target vs. Source',fontsize=15)\r\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\r\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\r\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#Root Squared Error between GT and Prediction\r\n"," plt.subplot(3,3,9)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\r\n","\r\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\r\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\r\n"," plt.title('Target vs. 
Prediction',fontsize=15)\r\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\r\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\r\n","\r\n","\r\n","#Make a pdf summary of the QC results\r\n","\r\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that checkpoint exist, if not the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SXqS_EhByhQ7"},"source":["## **6.2. Inspect the predicted output**\r\n","---\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"64emoATwylxM"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\r\n","import os\r\n","# This will display a randomly chosen dataset input and predicted output\r\n","random_choice = random.choice(os.listdir(Data_folder))\r\n","\r\n","\r\n","random_choice_no_extension = os.path.splitext(random_choice)\r\n","\r\n","\r\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\r\n","\r\n","\r\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\r\n","\r\n","f=plt.figure(figsize=(16,8))\r\n","plt.subplot(1,2,1)\r\n","plt.imshow(x, interpolation='nearest')\r\n","plt.title('Input')\r\n","plt.axis('off');\r\n","\r\n","plt.subplot(1,2,2)\r\n","plt.imshow(y, interpolation='nearest')\r\n","plt.title('Prediction')\r\n","plt.axis('off');\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
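If you only want to download the predicted images themselves, they are written by section 6.1 into `Result_folder/<model>/test_<checkpoint>/images/` as `*_fake.png` files (the same paths used for display in section 6.2). Below is a minimal sketch, not part of the notebook, for gathering them into one flat folder before download; the helper name `collect_predictions` and the output path are illustrative assumptions.

```python
# Illustrative sketch only: gather the predicted "_fake.png" images from the
# results tree written in section 6.1 into one flat folder, ready to download.
import glob
import os
import shutil

def collect_predictions(result_folder, model_name, checkpoint, out_folder):
    images_dir = os.path.join(result_folder, model_name, "test_" + str(checkpoint), "images")
    os.makedirs(out_folder, exist_ok=True)
    for path in glob.glob(os.path.join(images_dir, "*_fake.png")):
        shutil.copyfile(path, os.path.join(out_folder, os.path.basename(path)))

# Example with hypothetical paths:
# collect_predictions(Result_folder, Prediction_model_name, checkpoint, "/content/to_download")
```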
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\r\n","#**Thank you for using CycleGAN!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb index f669bbc1..be4b7e17 100644 --- a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1w95RljMrg15FLDRnEJiLIEa-lW-jEjQS","timestamp":1602684895691},{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","# if tf.__version__ != '2.2.0':\n","# !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","#@markdown ##Install Deep-STORM and dependencies\n","\n","\n","# %% Model definition + helper functions\n","\n","!pip install fpdf\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. 
for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," 
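The `L1L2loss` defined above blurs the predicted spike image with the 7x7 Gaussian filter and compares it to the target heatmap (MSE term), while an L1 term keeps the raw spike prediction sparse. The sketch below shows the same computation in plain NumPy on a single toy patch; it is illustrative only (toy values, a locally defined kernel), not the Keras loss used for training.

```python
# Illustrative NumPy version of the combined heatmap-MSE + L1 loss (toy values only).
import numpy as np
from scipy.signal import fftconvolve

yy, xx = np.mgrid[-3:4, -3:4]                    # 7x7 Gaussian kernel, sigma = 1
toy_kernel = np.exp(-(xx**2 + yy**2) / 2.0)
toy_kernel /= toy_kernel.sum()

spikes_pred = np.zeros((16, 16), dtype=np.float32)
spikes_pred[5, 9] = 80.0                         # one predicted emitter (toy value)
heatmap_true = np.zeros_like(spikes_pred)
heatmap_true[5, 8] = 100.0                       # ground-truth spike, slightly offset
heatmap_true = fftconvolve(heatmap_true, toy_kernel, mode="same")   # blurred to a heatmap

heatmap_pred = fftconvolve(spikes_pred, toy_kernel, mode="same")
mse_term = np.mean((heatmap_true - heatmap_pred) ** 2)
l1_term = np.mean(np.abs(spikes_pred))           # sparsity penalty on the spikes
print("toy combined loss:", mse_term + l1_term)
```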
function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # 
Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). 
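Once training completes, the `Quality Control/training_evaluation.csv` written just below (columns `loss`, `val_loss`, `learning rate`) can be reloaded to inspect convergence. A minimal sketch, assuming the file exists and pandas/matplotlib are available; the helper name `plot_training_curves` is an illustrative assumption, not part of the notebook.

```python
# Illustrative sketch: reload the training log written below and plot the curves.
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_training_curves(model_path):
    csv_path = os.path.join(model_path, "Quality Control", "training_evaluation.csv")
    history = pd.read_csv(csv_path)              # columns: loss, val_loss, learning rate
    plt.figure(figsize=(8, 4))
    plt.semilogy(history["loss"], label="training loss")
    plt.semilogy(history["val_loss"], label="validation loss")
    plt.xlabel("epoch")
    plt.ylabel("loss (log scale)")
    plt.legend()
    plt.show()
```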
\n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = 
np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 
1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," 
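The `max_layer` call in the batch loop above performs the peak detection: a pixel of the predicted density map is kept when it equals the maximum of its neighbourhood and exceeds the threshold. The NumPy/SciPy sketch below illustrates that rule on a toy array; it is not the TensorFlow `Maximafinder` layer the notebook actually uses, and all values are made up.

```python
# Illustrative sketch of the local-maxima rule used by Maximafinder, in NumPy.
import numpy as np
from scipy.ndimage import maximum_filter

rng = np.random.default_rng(0)
density = rng.random((32, 32)).astype(np.float32)     # toy density map
neighborhood_size, thresh = 3, 0.95

local_max = maximum_filter(density, size=neighborhood_size)
peaks = np.argwhere((density == local_max) & (density > thresh))
print(peaks)                                          # (row, col) of candidate emitters
```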
# Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos"},"source":["\n","# **3. 
Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). "]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","cellView":"form"},"source":["#@markdown ##Load raw data\n","\n","load_raw_data = True\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. \n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","cellView":"form"},"source":["load_raw_data = False\n","\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = 
[int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = 
Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","cellView":"form"},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decreae the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. 
Fewer may be generated depending on how many pacthes are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balancing the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be autimatically calculated using an empiraical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","cellView":"form"},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, 
np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n"," \n"," plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","cellView":"form"},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(WARNING+'!! WARNING: No patches were found in memory currently. !!')\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
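The next cell implements this lookup. As a standalone sketch, the learning rate of a previous run can be recovered from the `training_evaluation.csv` file that ZeroCostDL4Mic saves in the model's `Quality Control` folder (this assumes that folder layout and the `learning rate`/`val_loss` column names described above, and that `pretrained_model_path` points to the model folder provided in the next cell):

```python
import os
import pandas as pd

# Sketch only: recover the learning rate saved by a previous training run.
# Assumes 'pretrained_model_path' points to a ZeroCostDL4Mic model folder
# containing 'Quality Control/training_evaluation.csv' with 'learning rate'
# and 'val_loss' columns, as described above.
csv_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')

if os.path.exists(csv_path):
    history = pd.read_csv(csv_path)
    last_lr = history['learning rate'].iloc[-1]                            # used when Weights_choice == 'last'
    best_lr = history.loc[history['val_loss'].idxmin(), 'learning rate']   # used when Weights_choice == 'best'
    print('Last learning rate:', last_lr, '| learning rate at best val_loss:', best_lr)
else:
    print('No training history found; initial_learning_rate will be used instead.')
```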
"]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA"},"source":["## **4.4. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","\n","\n","\n","# -------------------------------------------\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Deep-STORM'\n","#model_name = 
'little_CARE_test'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","if load_raw_data == True:\n"," shape = (M,N)\n","else:\n"," shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n","#dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+' GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(180, 5, txt = text, align='L')\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if load_raw_data==False:\n"," simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n"," pdf.cell(200, 5, txt=simul_text, align='L')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Setting</th><th>Simulated Value</th></tr>
<tr><td>FOV_size</td><td>{0}</td></tr>
<tr><td>pixel_size</td><td>{1}</td></tr>
<tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
<tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
<tr><td>ADC_offset</td><td>{4}</td></tr>
<tr><td>emitter_density</td><td>{5}</td></tr>
<tr><td>emitter_density_std</td><td>{6}</td></tr>
<tr><td>number_of_frames</td><td>{7}</td></tr>
<tr><td>sigma</td><td>{8}</td></tr>
<tr><td>sigma_std</td><td>{9}</td></tr>
<tr><td>n_photons</td><td>{10}</td></tr>
<tr><td>n_photons_std</td><td>{11}</td></tr>
\n"," \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n"," pdf.write_html(html)\n","else:\n"," simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n"," pdf.multi_cell(190, 5, txt=simul_text, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," #pdf.ln(1)\n"," #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","# if Use_Default_Advanced_Parameters:\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n","pdf.ln(1)\n","html = \"\"\"\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n","pdf.write_html(html)\n","pdf.ln(3)\n","pdf.set_font('Arial', size=10)\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","
<tr><th>Patch Parameter</th><th>Value</th></tr>
<tr><td>patch_size</td><td>{0}</td></tr>
<tr><td>upsampling_factor</td><td>{1}</td></tr>
<tr><td>num_patches_per_frame</td><td>{2}</td></tr>
<tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
<tr><td>max_num_patches</td><td>{4}</td></tr>
<tr><td>gaussian_sigma</td><td>{5}</td></tr>
<tr><td>Automatic_normalization</td><td>{6}</td></tr>
<tr><td>L2_weighting_factor</td><td>{7}</td></tr>
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Training Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>initial_learning_rate</td><td>{4}</td></tr>
\n","\"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","# pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n","pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","print('------------------------------')\n","print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH"},"source":["##**4.5. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","cellView":"form"},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. 
GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = 
np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. 
GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. 
GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. 
GT PSNR\"],3)),fontsize=14)\n"," plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","\n","# -------------------------------------------------------------------\n","#Make a pdf summary of the QC results\n","\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Deep-STORM'\n","#model_name = os.path.basename(full_QC_model_path)\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n","pdf.ln(1)\n","if os.path.exists(savePath+'/lossCurvePlots.png'):\n"," exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n"," pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(savePath+'/QC_example_data.png').shape\n","pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","\n","print('------------------------------')\n","print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This paramter determines threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in less localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This paramter determines size of the neighborhood within which the prediction needs to be a local maxima in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This paramter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. 
**DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","cellView":"form"},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display 
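`batchFramePredictionLocalization` is defined earlier in the notebook; as a rough, hypothetical illustration of how the `threshold` and `neighborhood_size` parameters act on the network output (and ignoring the optional centre-of-gravity averaging), a density map can be reduced to localizations roughly like this. The function name and the 12.5 nm high-resolution pixel size are examples only:

```python
# Hypothetical sketch (not the notebook's implementation): keep pixels that are
# local maxima within `neighborhood_size` AND above `threshold`, then convert
# their indices on the upsampled grid into nanometre coordinates.
import numpy as np
from scipy.ndimage import maximum_filter

def density_to_localizations(density_map, threshold=0.1, neighborhood_size=3, pixel_size_hr=12.5):
    is_peak = density_map == maximum_filter(density_map, size=neighborhood_size)
    is_peak &= density_map > threshold
    rows, cols = np.nonzero(is_peak)
    return cols * pixel_size_hr, rows * pixel_size_hr   # x [nm], y [nm]
```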
-------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estmication (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bins are used ot build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bins needs to be interpolated to every single frames. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automaticaly saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","cellView":"form"},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, 
visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations 
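For reference, the per-bin drift estimate above relies on the fact that convolving the block image with the 180°-rotated reference (`np.rot90(..., k=2)`) is a cross-correlation, and the offset of its peak from the image centre tracks the relative shift between the two renderings. A minimal sketch is given below; it resolves the peak to integer pixels only, whereas the notebook refines it with `subPixelMaxLocalization`:

```python
# Minimal sketch of cross-correlation drift estimation between two rendered
# super-resolution histograms (integer-pixel precision only).
import numpy as np
from scipy.signal import fftconvolve

def estimate_shift(reference, block):
    xc = fftconvolve(np.rot90(reference, k=2), block, mode='same')
    peak_y, peak_x = np.unravel_index(np.argmax(xc), xc.shape)
    centre_y, centre_x = xc.shape[0] // 2, xc.shape[1] // 2
    # Shift of `block` relative to `reference`, in visualisation pixels
    return peak_x - centre_x, peak_y - centre_y
```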
saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","cellView":"form"},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not 
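The "Simple histogram" mode described in this section simply bins the localization coordinates onto a grid with the chosen `visualization_pixel_size`. A NumPy-only equivalent of the notebook's `FromLoc2Image_SimpleHistogram` (which does essentially the same with an explicit numba loop and clamps out-of-range points to the image edge) would be:

```python
# NumPy-only sketch of the "Simple histogram" rendering: one count per
# localization, binned on a grid of visualization_pixel_size-sized pixels.
import numpy as np

def render_simple_histogram(x_nm, y_nm, image_size=(256, 256), pixel_size=10):
    h, w = image_size
    img, _, _ = np.histogram2d(y_nm, x_nm, bins=(h, w),
                               range=[[0, h * pixel_size], [0, w * pixel_size]])
    return img
```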
implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","cellView":"form"},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1kD3rjN5XX5C33cQuX1DVc_n89cMqNvS_","timestamp":1610633423190},{"file_id":"1w95RljMrg15FLDRnEJiLIEa-lW-jEjQS","timestamp":1602684895691},{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. 
These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. 
To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","# if tf.__version__ != '2.2.0':\n","# !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-"},"source":["## **1.2. 
Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ"},"source":["# **2. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","#@markdown ##Install Deep-STORM and dependencies\n","\n","\n","# %% Model definition + helper functions\n","\n","!pip install fpdf\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return 
(im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," 
kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open 
later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). 
\n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = 
np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 
1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," 
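# Note (added for clarity): at this point all batches have been processed.
# 'frame_number_list', 'x_nm_list', 'y_nm_list' and 'confidence_au_list' hold
# one entry per detected emitter, with coordinates already converted to nm
# using pixel_size_hr. They are written out below as a localization table with
# the columns 'frame', 'x [nm]', 'y [nm]' and 'confidence [a.u]'.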
# Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, raw_data = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Deep-STORM'\n"," #model_name = 'little_CARE_test'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," if raw_data == True:\n"," shape = (M,N)\n"," else:\n"," shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n"," #dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if raw_data==False:\n"," simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n"," pdf.cell(200, 5, txt=simul_text, align='L')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Setting</th><th>Simulated Value</th></tr>
<tr><td>FOV_size</td><td>{0}</td></tr>
<tr><td>pixel_size</td><td>{1}</td></tr>
<tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
<tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
<tr><td>ADC_offset</td><td>{4}</td></tr>
<tr><td>emitter_density</td><td>{5}</td></tr>
<tr><td>emitter_density_std</td><td>{6}</td></tr>
<tr><td>number_of_frames</td><td>{7}</td></tr>
<tr><td>sigma</td><td>{8}</td></tr>
<tr><td>sigma_std</td><td>{9}</td></tr>
<tr><td>n_photons</td><td>{10}</td></tr>
<tr><td>n_photons_std</td><td>{11}</td></tr>
\n"," \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n"," pdf.write_html(html)\n"," else:\n"," simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n"," pdf.multi_cell(190, 5, txt=simul_text, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," #pdf.ln(1)\n"," #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n"," pdf.write_html(html)\n"," pdf.ln(3)\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n","
<tr><th>Patch Parameter</th><th>Value</th></tr>
<tr><td>patch_size</td><td>{0}</td></tr>
<tr><td>upsampling_factor</td><td>{1}</td></tr>
<tr><td>num_patches_per_frame</td><td>{2}</td></tr>
<tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
<tr><td>max_num_patches</td><td>{4}</td></tr>
<tr><td>gaussian_sigma</td><td>{5}</td></tr>
<tr><td>Automatic_normalization</td><td>{6}</td></tr>
<tr><td>L2_weighting_factor</td><td>{7}</td></tr>
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Training Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>initial_learning_rate</td><td>{4}</td></tr>
\n"," \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," # pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n"," pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Deep-STORM'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(savePath+'/lossCurvePlots.png'):\n"," exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n"," pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(savePath+'/QC_example_data.png').shape\n"," pdf.image(savePath+'/QC_example_data.png', x = 16, 
y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). 
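For reference, below is a minimal sketch of the inputs this section expects; the file paths are purely illustrative. The image data is a TIFF stack and the localization file is a CSV table containing at least the columns `frame`, `x [nm]` and `y [nm]` (ThunderSTORM-style).

```python
# Minimal sketch of the expected inputs (illustrative paths, not part of the notebook):
import pandas as pd
from skimage import io

stack = io.imread('/content/simulated_stack.tif')               # TIFF stack of raw frames
locs = pd.read_csv('/content/simulated_locs.csv', index_col=0)  # localization table

# The rest of the notebook relies on these column names being present.
assert {'frame', 'x [nm]', 'y [nm]'}.issubset(locs.columns)
print(stack.shape, '-', len(locs), 'localizations')
```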
"]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","cellView":"form"},"source":["#@markdown ##Load raw data\n","\n","load_raw_data = True\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. \n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","cellView":"form"},"source":["load_raw_data = False\n","\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = 
[int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = 
Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","cellView":"form"},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decreae the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. 
Fewer may be generated depending on how many pacthes are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balancing the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be autimatically calculated using an empiraical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","cellView":"form"},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: '+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, 
np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n"," \n"," plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","cellView":"form"},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(WARNING+'!! WARNING: No patches were found in memory currently. !!')\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
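As a minimal sketch of what the cell below does when recovering the learning rate, assuming a ZeroCostDL4Mic-style model folder containing `Quality Control/training_evaluation.csv` with the columns `loss`, `val_loss` and `learning rate` (one row per epoch):

```python
# Sketch only: mirrors the learning-rate lookup performed by the cell below.
import os
import pandas as pd

def recover_learning_rates(pretrained_model_path):
    csv_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')
    history = pd.read_csv(csv_path)
    last_lr = history['learning rate'].iloc[-1]                            # used when Weights_choice == 'last'
    best_lr = history.loc[history['val_loss'].idxmin(), 'learning rate']   # used when Weights_choice == 'best'
    return last_lr, best_lr
```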
"]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA"},"source":["## **4.4. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Export pdf summary \n","pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, 
\"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# export pdf after training to update the existing document\n","pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","cellView":"form"},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, 
normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. 
GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. 
GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n"," plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","# Export pdf wth summary of QC results\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. 
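For instance, with the default value of 4, the raw image stack is passed through the network four frames at a time. 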
A higher \`batch_size\` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the \`batch_size\`. **DEFAULT: 4**\n","\n","**\`threshold\`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher \`threshold\` will result in fewer localizations. **DEFAULT: 0.1**\n","\n","**\`neighborhood_size\`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high \`neighborhood_size\` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**\`use_local_average\`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","cellView":"form"},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display -------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," 
plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and is displayed at the \`upsampling_factor\` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**\`Loc_file_path\`:** is the path to the localization file to use for visualization.\n","\n","**\`original_image_path\`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**\`visualization_pixel_size\`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the drift correction estimation (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**\`number_of_bins\`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**\`polynomial_fit_degree\`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by a polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automatically saved in the \`save_path\` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","cellView":"form"},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, 
visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations 
saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","cellView":"form"},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not 
implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","cellView":"form"},"source":["# @markdown ---\n","# @markdown #Play this cell to save the visualization\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv b/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv index 09601587..809bdcb8 100644 --- a/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv +++ b/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv @@ -1 +1 @@ -1.11 +1.12 diff --git a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb index cc2454ee..558dfeda 100644 --- a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1gmHK1IWDAiHJ0-qkdOY0z0IX2RYY65Gy","timestamp":1602520721702},{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.11.1']\n","\n","\n","#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1.\n","!pip install q keras==2.2.5\n","\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**\`Training_source\`:** This is the path to the folder containing your Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**\`model_name\`:** Use only my_model-style names, not my-model (use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**\`model_path\`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**\`number_of_epochs\`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n"," \n","**\`patch_size\`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**\`batch_size\`:** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**\`number_of_steps\`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**\`percentage_validation\`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**\`initial_learning_rate\`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
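As a complement to the naming rules above (`model_name` should use underscores, no spaces or dashes, and should not re-use an existing folder), the following is an illustrative sanity check only; the example values are not taken from the notebook.

```python
import os
import re

# Example values only; in the notebook these come from the form fields above.
model_name = "my_model"
model_path = "/content/gdrive/MyDrive/results"

# Warn about spaces or dashes in the model name.
if re.search(r"[\s\-]", model_name):
    print("Warning: use underscores, e.g. 'my_model', instead of spaces or dashes.")

# Warn if a model with the same name already exists (it would be overwritten).
if os.path.exists(os.path.join(model_path, model_name)):
    print(f"Warning: {model_name} already exists and would be overwritten.")
```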
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V2D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"-Vy-vV7ssabS","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. 
This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: 
'+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"PXcLuX5jbNUv"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","id":"rBelu-LtbOTh"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
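The cell above recovers the learning rate either from the last epoch (`Weights_choice = "last"`) or from the epoch with the lowest validation loss (`"best"`), using the `training_evaluation.csv` written during training. Below is a minimal standalone sketch of that lookup, assuming the same three-column CSV layout (`loss`, `val_loss`, `learning rate`); the file path is an example.

```python
import pandas as pd

# Path is illustrative; the notebook reads it from the pretrained model's
# 'Quality Control' folder.
history = pd.read_csv("Quality Control/training_evaluation.csv")

last_learning_rate = history["learning rate"].iloc[-1]
best_learning_rate = history.loc[history["val_loss"].idxmin(), "learning rate"]

print("last:", last_learning_rate, "best:", best_learning_rate)
```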
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(Xdata.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). 
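For intuition about the `N2VConfig` above: `n2v_perc_pix` is the percentage of pixels that are masked (and must be predicted) in each training patch. With the values used here that is roughly 8 pixels per 64 x 64 patch; a quick back-of-the-envelope check:

```python
# Rough arithmetic only; values match the configuration above.
patch_size = 64
n2v_perc_pix = 0.198            # percentage of pixels manipulated per patch

pixels_per_patch = patch_size * patch_size           # 4096
masked_pixels = pixels_per_patch * n2v_perc_pix / 100
print(round(masked_pixels, 1))                       # ~8.1 masked pixels per patch
```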
In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"fisJmA13Mv5e","scrolled":true,"cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Noise2Void 2D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, 
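Regarding the 12-hour Colab limit mentioned above, it can help to time the first epoch and extrapolate before committing to a long run. This is a rough planning sketch with illustrative numbers, not part of the notebook itself:

```python
# Illustrative numbers; read seconds_per_epoch off the first epoch's progress bar.
seconds_per_epoch = 150
number_of_epochs = 200

estimated_hours = seconds_per_epoch * number_of_epochs / 3600
print(f"Estimated training time: {estimated_hours:.1f} h")

if estimated_hours > 12:
    print("Reduce number_of_epochs / patches, or plan to resume from saved weights (section 3.3).")
```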
shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
number_of_steps{3}
percentage_validation{4}
initial_learning_rate{5}
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","# pdf.set_font('')\n","# pdf.set_font('Arial', size = 10, style = 'B')\n","# pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","# pdf.set_font('')\n","# pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_N2V2D.png').shape\n","pdf.image('/content/TrainingDataExample_N2V2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Vd9igRYvSnTr"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
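As a follow-up to the loss curves above, the epoch with the lowest validation loss can be read directly from the same CSV, and a final validation loss well above the training loss is the overfitting pattern described in this section. A minimal sketch (the path is an example):

```python
import pandas as pd

history = pd.read_csv("Quality Control/training_evaluation.csv")   # example path

best_epoch = int(history["val_loss"].idxmin())
print(f"Lowest validation loss {history['val_loss'].min():.4f} at epoch {best_epoch + 1}")

# A large positive gap at the end of training suggests overfitting.
gap = history["val_loss"].iloc[-1] - history["loss"].iloc[-1]
print(f"Final val_loss - loss gap: {gap:.4f}")
```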
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create 
the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," 
writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. 
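For reference, the headline metrics written to the CSV above can be reproduced on any pair of normalised images with a few lines of scikit-image/NumPy. The toy example below uses synthetic data in [0, 1] (`data_range=1.0`, as in this cell); note that the QC cell derives its NRMSE from the RSE map, whereas the sketch reports the plain RMSE for simplicity.

```python
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# Synthetic ground truth and "prediction", both normalised to [0, 1].
rng = np.random.default_rng(0)
ground_truth = rng.random((128, 128)).astype(np.float32)
prediction = np.clip(ground_truth + rng.normal(0, 0.05, ground_truth.shape), 0, 1).astype(np.float32)

# full=True also returns the per-pixel SSIM image (the "SSIM map" shown above).
mssim, ssim_map = structural_similarity(ground_truth, prediction, data_range=1.0, full=True)
rmse = float(np.sqrt(np.mean((ground_truth - prediction) ** 2)))
psnr_value = peak_signal_noise_ratio(ground_truth, prediction, data_range=1.0)

print(f"mSSIM: {mssim:.3f}  RMSE: {rmse:.3f}  PSNR: {psnr_value:.1f} dB")
```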
Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Noise2Void 2D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," 
NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks"]},{"cell_type":"code","metadata":{"id":"bl3EdYFVS7X9","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n"," # r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n","\n","# Loop through the files\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n"," print(\"Images saved into folder:\", Result_folder)\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," timelapse = imread(os.path.join(r, file))\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," prediction_stack[t] = model.predict(img_t, axes='YX', n_tiles=(2,1))\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," imsave(os.path.join(outputdir, base_filename), prediction_stack_32) \n"," \n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PfTw_pQUUAqB"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"jFp-0y4zT_gL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","if Data_type == 1 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x, interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y, interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","if Data_type == 2 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[1], interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[1], interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
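As a purely illustrative example, the folder structure suggested above can be created in a few lines once your Google Drive is mounted; the base path below is a placeholder and should be adapted to your own Drive:

```python
# Hypothetical helper: create the suggested data layout on a mounted Google Drive.
# The base path is a placeholder - adapt it to your own folders.
import os

base = "/content/gdrive/My Drive/Data"
for sub in ["Training dataset",
            "Quality control dataset/Low SNR images",
            "Quality control dataset/High SNR images",
            "Data to be predicted",
            "Results"]:
    os.makedirs(os.path.join(base, sub), exist_ok=True)
print("Created folders under", base)
```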
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install Noise2Void and dependencies\n","\n","# Here we enable Tensorflow 1.\n","!pip install q keras==2.2.5\n","\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>initial_learning_rate</td><td>{5}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V2D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 
2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n"," #Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V2D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. 
This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: 
'+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(Xdata.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","pdf_export(pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this \n","point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
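If a run is stopped by the Colab time limit, training can be resumed from the weights saved by the previous run. The sketch below assumes the `model`, `X` and `X_val` objects created in section 4.1 and a placeholder path to the earlier model folder; section 3.3 offers the same mechanism through the notebook interface:

```python
# Sketch: resume an interrupted training run from previously saved weights.
# Assumes 'model', 'X' and 'X_val' from section 4.1; the folder below is a placeholder.
import os

previous_model_dir = "/content/gdrive/My Drive/models/my_n2v_model"
h5_file_path = os.path.join(previous_model_dir, "weights_best.h5")

if os.path.exists(h5_file_path):
    model.load_weights(h5_file_path)   # start from the weights of the earlier run
    history = model.train(X, X_val)    # continue training
else:
    print("No previous weights found at", h5_file_path)
```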
It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","pdf_export(trained = True, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! 
WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. 
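```python
# --- Illustration (synthetic data, not part of the QC pipeline below) -------------
# The metrics described above can be reproduced in isolation with the same
# scikit-image calls used further down in this cell; 'gt' stands in for a high-SNR
# target and 'pred' for a denoised prediction.
import numpy as np
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr

rng = np.random.default_rng(0)
gt = rng.random((128, 128)).astype(np.float32)
pred = gt + 0.05 * rng.normal(size=gt.shape).astype(np.float32)

# mSSIM and the per-pixel SSIM map (Gaussian weighting, sigma = 1.5)
mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                        gaussian_weights=True, sigma=1.5,
                                        use_sample_covariance=False)

# Root-squared-error map, NRMSE (computed as in this notebook) and PSNR
rse_map = np.sqrt(np.square(gt - pred))
nrmse = np.sqrt(np.mean(rse_map))
print("mSSIM:", round(float(mssim), 3),
      "NRMSE:", round(float(nrmse), 3),
      "PSNR:", round(psnr(gt, pred, data_range=1.0), 3))
# -----------------------------------------------------------------------------------
```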
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target 
(Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
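To make the single-image versus stack distinction concrete, the sketch below shows the two prediction paths. It assumes the N2V `model` object loaded in the cell below and uses placeholder file paths:

```python
# Sketch (placeholder paths; assumes the N2V 'model' loaded in the cell below).
# A single 2D image is denoised directly, a time-lapse stack frame by frame.
import numpy as np
from tifffile import imread, imsave

img = imread("/content/example.tif")

if img.ndim == 2:      # single image
    pred = model.predict(img, axes='YX', n_tiles=(2, 1))
    imsave("/content/example_denoised.tif", np.float32(pred))
elif img.ndim == 3:    # stack: loop over time points
    pred = np.stack([model.predict(frame, axes='YX', n_tiles=(2, 1)) for frame in img])
    imsave("/content/example_denoised.tif", np.float32(pred))
```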
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks"]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n"," # r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n","\n","# Loop through the files\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n"," print(\"Images saved into folder:\", Result_folder)\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," timelapse = imread(os.path.join(r, file))\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," prediction_stack[t] = model.predict(img_t, axes='YX', n_tiles=(2,1))\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," imsave(os.path.join(outputdir, base_filename), prediction_stack_32) \n"," \n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"67_8rEKp8C-z"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"n-stU-f08Cae"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","if Data_type == 1 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x, interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y, interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","if Data_type == 2 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[1], interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[1], interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb index cd50006b..6a8e71c1 100644 --- a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
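This is possible because of the blind-spot masking strategy described in the introduction above. The short numpy sketch below only illustrates the principle on a dummy 2D patch; it is **not** the internal implementation used by the n2v package, and all sizes are arbitrary:

```python
# Conceptual sketch of the blind-spot idea (NOT the n2v library's internal code):
# a few pixels of a noisy patch are replaced by values drawn from their
# neighbourhood, and the network is trained to predict the original values there.
import numpy as np

rng = np.random.default_rng(0)
noisy_patch = rng.normal(size=(64, 64)).astype(np.float32)    # stand-in for real data

n_masked = int(0.002 * noisy_patch.size)                      # ~0.2 % of pixels
ys = rng.integers(1, 63, n_masked)
xs = rng.integers(1, 63, n_masked)

network_input = noisy_patch.copy()
targets = noisy_patch[ys, xs].copy()                          # values to be predicted
for y, x in zip(ys, xs):
    dy, dx = rng.integers(-1, 2, 2)                           # random neighbour offset
    network_input[y, x] = noisy_patch[y + dy, x + dx]
# A training step would then penalise the network output only at the masked (ys, xs).
```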
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install Noise2Void and dependencies\n","!pip install q keras==2.2.5\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V3D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
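For illustration, the sketch below mimics on a dummy `(Z, Y, X)` patch what the `augment` option of `N2V_DataGenerator` adds internally (XY rotations plus a flip along X, which is why the patches must be square in XY); the real augmentation is handled for you when the patches are generated in section 4.1:

```python
# Conceptual sketch of the augmentations applied when augment=True.
# The patch here is a dummy array; shapes follow the (Z, Y, X) convention.
import numpy as np

patch = np.zeros((4, 64, 64), dtype=np.float32)                 # dummy (Z, Y, X) patch

augmented = [np.rot90(patch, k, axes=(1, 2)) for k in range(4)]  # 90-degree XY rotations
augmented += [np.flip(a, axis=2) for a in augmented]             # flips along the X axis
print(len(augmented), "patches derived from 1")                  # -> 8
```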
Disable this option is you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not 
exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
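As a quick sanity check before launching a long run, you can time a short test run and extrapolate; all numbers in the sketch below are hypothetical examples, not measured values:

```python
# Back-of-envelope check of the training time budget.
# Time one or two epochs of your own run first, then extrapolate.
seconds_per_epoch = 180            # example value measured from a short test run
number_of_epochs = 200
estimated_hours = seconds_per_epoch * number_of_epochs / 3600
print(f"Estimated training time: {estimated_hours:.1f} h (keep this well below 12 h)")
```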
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"scrolled":true,"cellView":"form","id":"iwNmp1PUzRDQ"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 3D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='ZYX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Noise2Void 3D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," 
version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table>
  <tr><th>Parameter</th><th>Value</th></tr>
  <tr><td>number_of_epochs</td><td>{0}</td></tr>
  <tr><td>patch_size</td><td>{1}</td></tr>
  <tr><td>batch_size</td><td>{2}</td></tr>
  <tr><td>number_of_steps</td><td>{3}</td></tr>
  <tr><td>percentage_validation</td><td>{4}</td></tr>
  <tr><td>initial_learning_rate</td><td>{5}</td></tr>
</table>

\n","\"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","# pdf.set_font('')\n","# pdf.set_font('Arial', size = 10, style = 'B')\n","# pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","# pdf.set_font('')\n","# pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Training Image', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_N2V3D.png').shape\n","pdf.image('/content/TrainingDataExample_N2V3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nRaaG02xZh_N"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = 
psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Noise2Void 3D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Activate the pretrained model. 
\n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
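For convenience, the results can also be bundled into a single archive before downloading. The sketch below is only an added, hedged example and is not part of the notebook; the Result_folder path is an assumed placeholder and should be replaced by your own folder.

import shutil

Result_folder = "/content/gdrive/My Drive/Results"   # assumed example path
# Write Results_backup.zip next to the notebook so everything can be downloaded in one go.
archive_path = shutil.make_archive("/content/Results_backup", "zip", Result_folder)
print("Archive written to:", archive_path)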
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","#@markdown ##Install Noise2Void and dependencies\n","!pip install q keras==2.2.5\n","\n","# Enable the Tensorflow 1 instead of the Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install Noise2Void and other required packages\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","print(\"Noise2Void installed.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>initial_learning_rate</td><td>{5}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V3D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. 
By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V3D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
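For illustration, the sketch below shows the kind of XY rotations and X flips described above, applied to a single (Z, Y, X) patch with NumPy. This is only an added, hedged example; in the notebook itself the augmentation is handled internally by N2V_DataGenerator.generate_patches_from_list(..., augment=True) in section 4.1.

import numpy as np

patch = np.random.rand(4, 64, 64)                 # dummy (Z, Y, X) patch, square in XY
rotated = np.rot90(patch, k=1, axes=(1, 2))       # 90-degree rotation in the XY plane
flipped = np.flip(patch, axis=2)                  # flip along the X axis
print(patch.shape, rotated.shape, flipped.shape)  # all remain (4, 64, 64)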
Disable this option if you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be an N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the chosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check that the model exists ------------------------\n","# If the model path chosen does not contain a pretrained model, then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not 
exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divided into training and validation patch sets. This prevents overlap between the training and validation patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_steps as a function of the training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate is set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Create the Config object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","pdf_export(trained = False, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 3D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='ZYX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained=True, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
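As a minimal, self-contained sketch of how these metrics can be computed for a single pair of normalised slices (the array names are placeholders; the quality-control cell below applies the same scikit-image calls slice by slice):

    import numpy as np
    from skimage.metrics import structural_similarity
    from skimage.metrics import peak_signal_noise_ratio as psnr

    # Placeholder data: two normalised 2D slices of the same shape.
    gt = np.random.rand(128, 128).astype(np.float32)
    pred = np.clip(gt + 0.05 * np.random.randn(128, 128), 0, 1).astype(np.float32)

    # mSSIM and the full SSIM map (Gaussian-weighted window, sigma = 1.5).
    mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                            gaussian_weights=True, use_sample_covariance=False, sigma=1.5)

    # Root Squared Error map, the NRMSE value derived from it, and the PSNR.
    rse_map = np.sqrt(np.square(gt - pred))
    nrmse = np.sqrt(np.mean(rse_map))
    print(mssim, nrmse, psnr(gt, pred, data_range=1.0))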
The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of a longer runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = 
psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Activate the pretrained model. \n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
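For a single stack, a minimal sketch of how the tiling choice feeds into the prediction call looks like this (paths and the model name are placeholders, and the imports are assumed to be the same n2v/csbdeep/tifffile ones used when the notebook was set up):

    from tifffile import imread
    from csbdeep.io import save_tiff_imagej_compatible
    from n2v.models import N2V  # assumed import location for the N2V model class

    # Load a trained model from its folder (placeholder name and path).
    model = N2V(config=None, name="my_model", basedir="/content/gdrive/My Drive/models")

    # As described above, n_tilesZYX can be None (automatic) or a manual (Z, Y, X) tuple.
    n_tilesZYX = (1, 2, 2)

    img = imread("/path/to/one_stack.tif")  # placeholder path
    pred = model.predict(img, axes='ZYX', n_tiles=n_tilesZYX)
    save_tiff_imagej_compatible("/path/to/result.tif", pred, axes='ZYX')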
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb index 4e372efa..dde32525 100644 --- a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1tqoReM3xwZ_kybPQFodDX8R6KD7LG65r","timestamp":1602672180981},{"file_id":"1WAfQW1Mj3wy1XQZZUfU4DJVS_R_E8Cn3","timestamp":1585665697353},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. 
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC"},"source":["\n","\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R"},"source":["## **1.2. 
Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","cellView":"form"},"source":["\n","Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","!pip install fpdf\n","!pip install PTable # Nice tables \n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import 
cv2\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('------------------------------------------')\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"KWpu5p8utpE2"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. 
Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"CNJImzzVnr7h","cellView":"form"},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2 #@param {type:\"number\"}\n","number_of_steps = 20#@param {type:\"number\"}\n","patch_size = 96 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","if patch_size > 2048:\n"," patch_size = 2048\n"," print (bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 16\n","if not patch_size % 16 == 0:\n"," patch_size = ((int(patch_size / 16)-1) * 16)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_StarDist2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vgT0NU3P6Bwt"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8in3wzAw6G6g"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here via random rotations, flips, and intensity changes.\n","\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** "]},{"cell_type":"code","metadata":{"id":"2zk1H8J06aJH","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:10, step:1}\n","\n","\n","def random_fliprot(img, mask): \n"," assert img.ndim >= mask.ndim\n"," axes = tuple(range(mask.ndim))\n"," perm = tuple(np.random.permutation(axes))\n"," img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(perm) \n"," for ax in axes: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," x, y = random_fliprot(x, y)\n"," x = random_intensity_change(x)\n"," # add some gaussian noise\n"," sig = 0.02*np.random.uniform(0,1)\n"," x = x + sig*np.random.normal(0,1,x.shape)\n"," return x, y\n","\n","\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled\")\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"x4zMG4lMths-"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
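Below is a minimal sketch of how the learning rate can be recovered from a previous ZeroCostDL4Mic run, as described above (it assumes the model folder contains `Quality Control/training_evaluation.csv` with `val_loss` and `learning rate` columns, as saved by these notebooks; the helper name and paths are illustrative only, and the actual loading is performed by the code cell that follows):

```python
import os
import pandas as pd

def recover_learning_rate(model_folder, weights_choice="best", default_lr=0.0003):
    """Illustrative helper: return the learning rate to resume training with."""
    csv_path = os.path.join(model_folder, "Quality Control", "training_evaluation.csv")
    if not os.path.exists(csv_path):
        # Model trained outside ZeroCostDL4Mic: fall back to the default value
        return default_lr
    history = pd.read_csv(csv_path)
    if "learning rate" not in history.columns:
        # Older notebook versions did not record the learning rate
        return default_lr
    if weights_choice == "last":
        # Learning rate used at the final epoch
        return history["learning rate"].iloc[-1]
    # Learning rate at the epoch with the lowest validation loss ("best" weights)
    return history["learning rate"].loc[history["val_loss"].idxmin()]
```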
"]},{"cell_type":"code","metadata":{"id":"SfQeukJJtv9u","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," 
shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG"},"source":["#**4. 
Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= (int(len(X)/batch_size)+1)\n","\n","if (Use_Data_augmentation):\n"," augmentation_factor = Multiply_dataset_by\n"," number_of_steps = number_of_steps * augmentation_factor\n","\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W"},"source":["\n","## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. 
Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","\n","\n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'StarDist 2D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = 
gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n"," \n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>grid_parameter</td><td>{6}</td></tr>
<tr><td>initial_learning_rate</td><td>{7}</td></tr>
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,grid_parameter,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_StarDist2D.png').shape\n","pdf.image('/content/TrainingDataExample_StarDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","if Use_Data_augmentation:\n"," ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"o2O0QnO4PFlz","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-2b4RMU_Ec2y"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"KG8wZrA3Ef4n","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GFJBwr5TEgcq"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"G5pG-hFFa1av","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","from stardist.matching import matching\n","from stardist.plot import render_label, render_label_pred \n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n","#Normalize images.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_prediction, test_ground_truth_image, thresh=0.5)\n"," \n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," 
#Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Stardist 2D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," #image = header[0]\n"," #PvGT_IoU = header[1]\n"," fp = header[2]\n"," tp = header[3]\n"," fn = header[4]\n"," precision = header[5]\n"," recall = header[6]\n"," acc = header[7]\n"," f1 = header[8]\n"," n_true = header[9]\n"," n_pred = header[10]\n"," mean_true = header[11]\n"," mean_matched = header[12]\n"," panoptic = header[13]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," #image = row[0]\n"," PvGT_IoU = row[1]\n"," fp = row[2]\n"," tp = row[3]\n"," fn = row[4]\n"," precision = row[5]\n"," recall = row[6]\n"," acc = row[7]\n"," f1 = row[8]\n"," n_true = row[9]\n"," n_pred = row[10]\n"," mean_true = row[11]\n"," mean_matched = row[12]\n"," panoptic = row[13]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th><th>{8}</th><th>{9}</th><th>{10}</th><th>{11}</th><th>{12}</th><th>{13}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td><td>{8}</td><td>{9}</td><td>{10}</td><td>{11}</td><td>{12}</td><td>{13}</td></tr>
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iAPmwlxCEzxQ"},"source":["# **6. Using the trained model**\n","---"]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n"," \n"," \n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if 
Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," Nuclei_centre_coordinate = polygons['points']\n"," my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n"," my_df2.columns =['Y', 'X']\n"," \n"," my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_Nuclei_centre.csv', index=False, header=True)\n","\n","\n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", 
Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610969691998},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. 
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. 
Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["\n","Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","!pip install edt # improves STARDIST performances\n","!pip install wget\n","!pip install fpdf\n","!pip install PTable # Nice tables \n","\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('------------------------------------------')\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained=False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'StarDist 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n"," \n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>grid_parameter</td><td>{6}</td></tr>
<tr><td>initial_learning_rate</td><td>{7}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,grid_parameter,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_StarDist2D.png').shape\n"," pdf.image('/content/TrainingDataExample_StarDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Stardist 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," #image = header[0]\n"," #PvGT_IoU = header[1]\n"," fp = header[2]\n"," tp = header[3]\n"," fn = header[4]\n"," precision = header[5]\n"," recall = header[6]\n"," acc = header[7]\n"," f1 = header[8]\n"," n_true = header[9]\n"," n_pred = header[10]\n"," mean_true = header[11]\n"," mean_matched = header[12]\n"," panoptic = header[13]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," #image = row[0]\n"," PvGT_IoU = row[1]\n"," fp = row[2]\n"," tp = row[3]\n"," fn = row[4]\n"," precision = row[5]\n"," recall = row[6]\n"," acc = row[7]\n"," f1 = row[8]\n"," n_true = row[9]\n"," n_pred = row[10]\n"," mean_true = row[11]\n"," mean_matched = row[12]\n"," panoptic = row[13]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th><th>{8}</th><th>{9}</th><th>{10}</th><th>{11}</th><th>{12}</th><th>{13}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td><td>{8}</td><td>{9}</td><td>{10}</td><td>{11}</td><td>{12}</td><td>{13}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. 
**Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 4 #@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","patch_size = 512 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print(bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","if patch_size > 2048:\n"," patch_size = 2048\n"," print(bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 16\n","if not patch_size % 16 == 0:\n"," patch_size = ((int(patch_size / 16)-1) * 16)\n"," print(bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_StarDist2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here via random rotations, flips, and intensity changes.\n","\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:10, step:1}\n","\n","\n","def random_fliprot(img, mask): \n"," assert img.ndim >= mask.ndim\n"," axes = tuple(range(mask.ndim))\n"," perm = tuple(np.random.permutation(axes))\n"," img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(perm) \n"," for ax in axes: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," x, y = random_fliprot(x, y)\n"," x = random_intensity_change(x)\n"," # add some gaussian noise\n"," sig = 0.02*np.random.uniform(0,1)\n"," x = x + sig*np.random.normal(0,1,x.shape)\n"," return x, y\n","\n","\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled\")\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," 
shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. 
Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","#Normalize images and fill small label holes.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. 
\n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","# Currently always false for stability\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0): \n"," # number_of_steps= (int(len(X)/batch_size)+1)\n"," number_of_steps = Image_X*Image_Y/(patch_size*patch_size)*(int(len(X)/batch_size)+1)\n"," if (Use_Data_augmentation):\n"," augmentation_factor = Multiply_dataset_by\n"," number_of_steps = number_of_steps * augmentation_factor\n","\n","\n","\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","\n","\n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n","\n","pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","#Create a pdf document with training summary"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","from stardist.matching import matching\n","from stardist.plot import render_label, render_label_pred \n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n","#Normalize images.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_prediction, test_ground_truth_image, thresh=0.5)\n"," \n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," 
#Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei (Stacks only).\n","- A CSV file that contains the number of nuclei detected per image (single image only). \n","- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","Data_folder = Data_folder+\"/*.tif\"\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n"," \n"," \n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n"," \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n","\n"," # Save all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if 
Tracking_file:\n"," print(bcolors.WARNING+\"Tracking files are only generated when stacks are predicted\"+W) \n"," \n"," Nuclei_centre_coordinate = polygons['points']\n"," my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n"," my_df2.columns =['Y', 'X']\n"," \n"," my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_Nuclei_centre.csv', index=False, header=True)\n","\n","\n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","\n"," # One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n"," Y = sorted(glob(Data_folder))\n"," X = list(map(imread,Y))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," axis_norm = (0,1) # normalize channels independently\n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n"," #Load a pretrained network\n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))]\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension = []\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n","\n"," outputdir = Path(Results_folder)\n","\n","# Save all ROIs and images in Results folder.\n"," for num, i in enumerate(X):\n"," print(\"Performing prediction on: \"+names[num])\n","\n"," \n"," timelapse = np.stack(i)\n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," timelapse.shape\n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(os.path.join(outputdir, name_no_extension[num]), polygons) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n","\n","# Save the masks in the result folder\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(names[num], prediction_stack_32)\n"," if Tracking_file:\n"," imsave(name_no_extension[num]+\"_tracking_file.tif\", 
Tracking_stack_8_rot_flip)\n","\n"," \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\r\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb index 9f89768b..5b30421d 100644 --- a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"15zLlQlrxpv-lw8NrQhqnviWNxa1yqYAx","timestamp":1603114400753},{"file_id":"1Ur-4VIQ6gf4ONupD6hK0M-AcJkoTzMlU","timestamp":1586789439593},{"file_id":"1PKVyox_mx2rEE3VlMFQtdnVULJFhYPaD","timestamp":1583443864213},{"file_id":"1XSclOkhhHmn-9LQc9k8c3Y6seT1LEi-Y","timestamp":1583264105465},{"file_id":"1VPZYk3MeSVyZVVEmesz10VtujbD4diJk","timestamp":1579481583477},{"file_id":"1ENdOZir1Gytf6JxzyfbjgfxO3_C1dLHK","timestamp":1575415287126},{"file_id":"1G8b4dF2kCs3ePBGZthPUGOyjJpZ2G_Dm","timestamp":1575379725785},{"file_id":"1P0tT0RR_b3SFKvOcON_MzcAIcxRUQK5B","timestamp":1575377313115},{"file_id":"1hQz8PyJzBRkBZc9NwxM9mU9azRSvghBk","timestamp":1574783624098},{"file_id":"14mWTNjHgIbuuWAxb-0lhmhdIvMoZgrI0","timestamp":1574099686195},{"file_id":"1IWvFuBb0gqaJcUXhhfbcTWNh9cZEXW4S","timestamp":1573647131082},{"file_id":"1hFulBwI57YU6GoVc8sBt5KNIkCS7ynQ3","timestamp":1573579952409},{"file_id":"1Ba_Bu-PXN_2Mq5W6YHMgUYsJEfgbPtS-","timestamp":1573035984524},{"file_id":"1ePC44Qq_C2hSFGPM3PKyb0J6UBXSPddp","timestamp":1573032545399},{"file_id":"https://github.com/mpicbg-csbd/stardist/blob/master/examples/2D/2_training.ipynb","timestamp":1572984225873}],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kiFRRolPa-Rb"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D dataset. If you are interested in 2D dataset, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available on GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"iSuNqQ2ZMVGM"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain, from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"4-oByBSdE6DE"},"source":["#**0. 
Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"t1sYuLChbRV3"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CDxBu1-19OyC"},"source":["\n","\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"4waLStm0RPFo","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZLY4qhgj8w-R"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"Ukil4yuS8seC","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left. Refresh the tab. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bB0IaQMZmWYM"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"j0w7C8P5zPIp","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","\n","#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microscopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","!pip install fpdf\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DPWhXaltAYgH"},"source":["# **3. Select your parameters and paths**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nAW3oU60htR_"},"source":["## **3.1. 
Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"HJKFAmuXc6d1"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches used to train StarDist 3D (length of a side). The value should be smaller than or equal to the dimensions of the image. Make patch_size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set the number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as the learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"CNJImzzVnr7h"},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_StarDist3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbyf-RevQhDL"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UQ2hultWQlT9"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"id":"wYdTY6ULg01b","cellView":"form"},"source":["#@markdown ###See Elasticdeform’s license\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\n","\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\n","\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","print(\"Double click to see elasticdeform’s license\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKLB47jgQrxr","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = True #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," 
io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = 
imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pjz-5bRVh1ja"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"zeSUtd2Thw-O","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist'+W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == 
min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DECuc3HZDbwG"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"NwV5LweiavgQ"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"uTM781rCKT8r","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. 
If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model 
------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nnMCvu2PKT9W"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point."]},{"cell_type":"code","metadata":{"id":"XfCF-Q4lKT9e","cellView":"form"},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
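For readers who want on-the-fly augmentation rather than the pre-computed augmented folders used above, the commented-out stub in this training cell shows the signature that `model.train` expects for `augmenter`. Below is a minimal illustrative sketch (not part of the notebook's defaults): random flips applied identically to each image/mask pair, assuming ZYX stacks and that flipping is a valid augmentation for your data.

```python
import numpy as np

def augmenter(X_batch, Y_batch):
    # Illustrative only: flip each image/mask pair along Y and/or X with 50% probability
    X_aug, Y_aug = [], []
    for x, y in zip(X_batch, Y_batch):
        for ax in (1, 2):                      # axes 1 and 2 are Y and X in a ZYX stack
            if np.random.rand() < 0.5:
                x = np.flip(x, axis=ax)        # apply the same flip to image and mask
                y = np.flip(y, axis=ax)
        X_aug.append(x.copy())
        Y_aug.append(y.copy())
    return X_aug, Y_aug

# then pass it to training instead of None:
# history = model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augmenter, ...)
```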
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'StarDist 3D'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if Elastic_deformation == True:\n"," aug_text = aug_text+'\\n- elastic deformation'\n"," if Flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if Rotation == True:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n","\"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_StarDist3D.png').shape\n","pdf.image('/content/TrainingDataExample_StarDist3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- StarDist 3D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","ref_3 = '- StarDist 3D: Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n","pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","# if Use_Data_augmentation:\n","# ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_4, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iYRrmh0dCrNs"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"LqH54fYhdbXU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RzAHUsi-78Ak","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w3Z7Jkv8bPvq"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"05dbg6UrGunj","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mBkuXf5zhHUd"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"i9ek_kIHhK1R","cellView":"form"},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Stardist 3D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, 
txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," PvGT_IoU = header[1]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,PvGT_IoU)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," PvGT_IoU = row[1]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(PvGT_IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th></tr>
<tr><td>{0}</td><td>{1}</td></tr>
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = ' - Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U8H7QRfKBzI8"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"btXwwnVpBEMB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x8UXP8S2eoo_","cellView":"form"},"source":["from PIL import Image\n","\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. 
\"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SxJsrw3kTcFx"},"source":["## **6.2. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"rH_J20ydXWRQ"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610975750230},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D dataset. If you are interested in 2D dataset, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
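Because each source image must have a mask with exactly the same file name (see the data structure described below), a quick check before training can save a failed run. A minimal sketch with placeholder paths:

```python
import os

source = '/content/gdrive/MyDrive/Training - Images'   # placeholder paths
target = '/content/gdrive/MyDrive/Training - Masks'

src = sorted(f for f in os.listdir(source) if f.endswith('.tif'))
tgt = set(f for f in os.listdir(target) if f.endswith('.tif'))
missing = [f for f in src if f not in tgt]
print('%d source images, %d with a matching mask' % (len(src), len(src) - len(missing)))
if missing:
    print('Missing masks for:', missing)
```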
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. 
Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","\n","#@markdown ##Install StarDist and dependencies\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools\n","!pip install edt\n","!pip install wget\n","!pip install fpdf\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'StarDist 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by:'\n"," if Elastic_deformation == True:\n"," aug_text = aug_text+'\\n- elastic deformation'\n"," if Flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if Rotation == True:\n"," aug_text = aug_text+'\\n- rotation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_StarDist3D.png').shape\n"," pdf.image('/content/TrainingDataExample_StarDist3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 3D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- StarDist 3D: Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Stardist 3D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," PvGT_IoU = header[1]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,PvGT_IoU)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," PvGT_IoU = row[1]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(PvGT_IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th></tr>
<tr><td>{0}</td><td>{1}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = ' - Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," !pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 400 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 400**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches use to train StarDist 3D (length of a side). The value should be smaller or equal to the dimensions of the image. Make patch size and patch_height as large as possible and divisible by 8 and 4, respectively. 
**Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 400#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 1#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","patch_size = 64#@param {type:\"number\"} # pixels in\n","patch_height = 64#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters): \n"," patch_size = min(Image_Y, Image_X) \n"," patch_height = Image_Z\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_StarDist3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis as well as performing elastic deformations\n","\n","**The flip option and the elastic deformation will double the size of your dataset, rotation will quadruple and all together will increase the dataset by a factor of 16.**\n","\n"," Elastic deformations performed by [Elasticdeform.](https://elasticdeform.readthedocs.io/en/latest/index.html).\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"7RfiVP7qsPt0"},"source":["#@markdown ###See Elasticdeform’s license\r\n","#Copyright (c) 2001, 2002 Enthought, Inc. All rights reserved.\r\n","\r\n","#Copyright (c) 2003-2017 SciPy Developers. All rights reserved.\r\n","\r\n","#Copyright (c) 2018 Gijs van Tulder. All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\r\n","\r\n","##Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\r\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\r\n","#Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.\r\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r\n","\r\n","print(\"Double click to see elasticdeform’s license\")\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","#@markdown **Deform your images**\n","\n","Elastic_deformation = False #@param {type:\"boolean\"}\n","\n","Deformation_Sigma = 1 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," 
io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Training_source_augmented+'/'+image,source_img)\n"," io.imsave(Training_source_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Training_target_augmented+'/'+image,target_img)\n"," io.imsave(Training_target_augmented+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","\n","\n","if Use_Data_augmentation:\n","\n","\n"," if Elastic_deformation:\n"," !pip install elasticdeform\n"," import numpy, imageio, elasticdeform\n","\n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n","\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n"," Training_source_augmented = Augmented_folder+\"/Training_source\"\n"," os.makedirs(Training_source_augmented)\n"," Training_target_augmented = Augmented_folder+\"/Training_target\"\n"," os.makedirs(Training_target_augmented)\n"," print(\"Data augmentation enabled\")\n"," print(\"Generation of the augmented dataset in progress\")\n","\n"," if Elastic_deformation:\n"," for filename in os.listdir(Training_source):\n"," X = imread(os.path.join(Training_source, filename))\n"," Y = imread(os.path.join(Training_target, filename))\n"," [X_deformed, Y_deformed] = elasticdeform.deform_random_grid([X, Y], sigma=Deformation_Sigma, order=0)\n","\n"," os.chdir(Augmented_folder+\"/Training_source\")\n"," imsave(filename, X)\n"," imsave(filename+\"_deformed.tif\", X_deformed)\n","\n"," os.chdir(Augmented_folder+\"/Training_target\")\n"," imsave(filename, Y)\n"," imsave(filename+\"_deformed.tif\", Y_deformed)\n","\n"," Training_source_rot = Training_source_augmented\n"," Training_target_rot = Training_target_augmented\n"," \n"," if not Elastic_deformation:\n"," Training_source_rot = Training_source\n"," Training_target_rot = Training_target\n","\n"," \n"," if Rotation == True:\n"," rotation_aug(Training_source_rot,Training_target_rot,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_rot,Training_target_rot)\n","\n"," print(\"Done\")\n","\n"," if Elastic_deformation:\n"," from astropy.visualization import simple_norm\n"," norm = simple_norm(x, percent = 99)\n","\n"," random_choice=random.choice(os.listdir(Training_source))\n"," x = imread(Augmented_folder+\"/Training_source/\"+random_choice)\n"," x_deformed = 
imread(Augmented_folder+\"/Training_source/\"+random_choice+\"_deformed.tif\")\n"," y = imread(Augmented_folder+\"/Training_target/\"+random_choice)\n"," y_deformed = imread(Augmented_folder+\"/Training_target/\"+random_choice+\"_deformed.tif\") \n","\n"," Image_Z = x.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," f=plt.figure(figsize=(10,10))\n"," plt.subplot(2,2,1)\n"," plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Training source (single Z plane)');\n"," plt.subplot(2,2,2)\n"," plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Training target (single Z plane)');\n"," plt.subplot(2,2,3)\n"," plt.imshow(x_deformed[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n"," plt.axis('off')\n"," plt.title('Deformed training source (single Z plane)');\n"," plt.subplot(2,2,4)\n"," plt.imshow(y_deformed[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n"," plt.axis('off')\n"," plt.title('Deformed training target (single Z plane)');\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n","\n","\n","\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
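For reference, a minimal sketch of how the learning rate saved by ZeroCostDL4Mic could be read back from a model's `training_evaluation.csv` (the folder path and the fallback value of 0.0003 are illustrative assumptions; the column names match those written by the training cell of this notebook):

```python
import pandas as pd

initial_learning_rate = 0.0003  # notebook default, used here as a fallback
# Hypothetical path to a previously trained ZeroCostDL4Mic model folder.
csv_path = '/content/gdrive/My Drive/my_model/Quality Control/training_evaluation.csv'
history_df = pd.read_csv(csv_path)

if 'learning rate' in history_df.columns:
    last_lr = history_df['learning rate'].iloc[-1]                                # rate when training ended
    best_lr = history_df.loc[history_df['val_loss'].idxmin(), 'learning rate']    # rate at the lowest val_loss
else:
    last_lr = best_lr = initial_learning_rate                                     # older models: fall back to the default
```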
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," wget.download(\"https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist'+W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == 
min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. 
If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(len(X)/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model 
------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 
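As a small worked example, the Intersection over Union of two binarised masks can be sketched as follows (the toy arrays stand in for a thresholded prediction and its ground-truth mask; in the QC cell below the same operations are applied to full 3D stacks thresholded to 0/255):

```python
import numpy as np

# Toy 3x3 masks: non-zero pixels mark the object.
ground_truth = np.array([[0, 1, 1],
                         [0, 1, 0],
                         [0, 0, 0]])
prediction   = np.array([[0, 1, 0],
                         [0, 1, 1],
                         [0, 0, 0]])

intersection = np.logical_and(ground_truth > 0, prediction > 0)
union = np.logical_or(ground_truth > 0, prediction > 0)
iou_score = intersection.sum() / union.sum()   # 2 overlapping pixels / 4 pixels in the union = 0.5
```

A score of 1 means the prediction and the ground truth overlap perfectly; a score of 0 means they do not overlap at all.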
\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for 
\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n","Image_Z = test_input.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Display the last image\n","\n","f=plt.figure(figsize=(25,25))\n","\n","from astropy.visualization import simple_norm\n","norm = simple_norm(test_input, percent = 99)\n","\n","#Input\n","plt.subplot(1,4,1)\n","plt.axis('off')\n","plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n","plt.title('Input')\n","\n","#Ground-truth\n","plt.subplot(1,4,2)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n","plt.title('Ground Truth')\n","\n","#Prediction\n","plt.subplot(1,4,3)\n","plt.axis('off')\n","plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n","plt.title('Prediction')\n","\n","#Overlay\n","plt.subplot(1,4,4)\n","plt.axis('off')\n","plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n","plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n","plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)))\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","#Make a pdf summary of the QC results\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["from PIL import Image\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," # One example image \n","print(\"One example image is displayed bellow:\")\n","plt.figure(figsize=(13,10))\n","z = max(0, img.shape[0] // 2 - 5)\n","plt.subplot(121)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.title('Raw image (XY slice)')\n","plt.axis('off')\n","plt.subplot(122)\n","plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n","plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n","plt.title('Image and predicted labels (XY slice)')\n","plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb index 8b9d19ad..7da85fe0 100644 --- a/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Template_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1owWtQQucUxUOZMaPh2x_mxe_qXKHCZhp","timestamp":1588074588514},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"5Yw8pstKuR7_"},"source":[" This is a template for a ZeroCostDL4Mic notebook and needs to be filled with appropriate model code and information.\n","\n"," Thank you for contributing to the ZeroCostDL4Mic Project. Please use this notebook as a template for your implementation. When your notebook is completed, please upload it to your github page and send us a link so we can reference your work.\n","\n"," If possible, remember to provide separate training and test datasets (for quality control) containing source and target images with your finished notebooks. This is very useful so that ZeroCostDL4Mic users can test your notebook. "]},{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **Name of the Network**\n","\n","---\n","\n"," Description of the network and link to publication with author reference. 
[author et al, etc.](URL).\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is inspired from the *Zero-Cost Deep-Learning to Enhance Microscopy* project (ZeroCostDL4Mic) (https://github.com/HenriquesLab/DeepLearning_Collab/wiki) and was created by **Your name**\n","\n","This notebook is based on the following paper: \n","\n","**Original Title of the paper**, Journal, volume, pages, year and complete author list, [link to paper](URL)\n","\n","And source code found in: *provide github link or equivalent if applicable*\n","\n","Provide information on dataset availability and link for download if applicable.\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. 
Before getting started**\n","---\n"," Give information on the required structure and dataype of the training dataset.\n","\n"," Provide information on quality control dataset, such as:\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n","\n","from tensorflow.python.client import device_lib \n","device_lib.list_local_devices()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install Name of the network and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["#@markdown ##Install Network and dependencies\n","\n","#Libraries contains information of certain topics. \n","\n","#Put the imported code and libraries here\n","\n","Notebook_version = ['1.11'] #Contact the ZeroCostDL4Mic team to find out about the version number\n","\n","!pip install fpdf\n","\n","\n","print(\"Depencies installed and imported.\")\n","\n","# Exporting requirements.txt for local run \n","# -- the developers should leave this below all the other installations\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n"]},{"cell_type":"markdown","metadata":{"id":"da_R1mCG_PDX"},"source":["## **3.1. Setting the main training parameters**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n"," Fill the parameters here as needed and update the code. Note that the sections containing `Training_source`, `Training target`, `model_name` and `model_path` should appear in your notebook.\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. 
**Default value:**\n","\n","**`other_parameters`:**Give other parameters or default values **Default value:**\n","\n","**If additional parameter above affects the training of the notebook give a brief explanation and how problems can be mitigated** \n","\n","\n","**Advanced parameters - experienced users only**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Other parameters, add as necessary\n","other_parameters = 80#@param {type:\"number\"} # in pixels\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","number_of_steps = 400#@param {type:\"number\"}\n","batch_size = 16#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","#We save the example data here to use it in the pdf export of the training\n","plt.savefig('/content/NetworkNameExampleData.png', bbox_inches='tight', pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wA66DlgI_Bya"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"opQ2MwPy_HFC"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. 
This only works if the images are square in XY.\n","\n","Add any other information which is necessary to run augmentation with your notebook/data."]},{"cell_type":"code","metadata":{"id":"pcWXnWP0_WRn","cellView":"form"},"source":["#@markdown ###Add any further useful augmentations\n","Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lasWo8w6B5BM"},"source":["## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a model of Your Network**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"Wr5O55VuB6t5","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n","\n","#@markdown ### You will need to add or replace the code that loads any previously trained weights to the notebook here."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.1. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"EZnoS3rb8BSR","cellView":"form"},"source":["import time\n","import csv\n","\n","start = time.time()\n","\n","#@markdown ##Start training\n","\n","# Start Training\n","\n","#Insert the code necessary to initiate training of your model\n","\n","#Note that the notebook should load weights either from the model that is \n","#trained from scratch or if the pretrained weights are used (3.3.)\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","\n","# Most likely this section will require some tinkering \n","# by the user to get the pdf output to look nice.\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = \"Your Network's name\"\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main 
Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","#if Use_pretrained_model:\n","# text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. 
The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","# if Use_Default_Advanced_Parameters:\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
percentage_validation{0}
steps{1}
batch_size{2}
\n","\"\"\".format(percentage_validation,steps,batch_size)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_'+Network_name+'.png').shape\n","pdf.image('/content/TrainingDataExample_'+Network_name+'.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Name of method: reference'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku"},"source":["##**4.2. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below. 
\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. 
epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n"," Update the code below to perform predictions on your quality control dataset. Use the metrics that are the most meaningful to assess the quality of the prediction.\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Insert code to activate the pretrained model if necessary. 
\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Insert code to perform predictions on all datasets in the Source_QC folder\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target 
(Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Label-free prediction (fnet)'\n","#model_name = os.path.basename(QC_model_folder)\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in 
metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}
{0}{1}{2}{3}{4}
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Name of method: Reference'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","Fill the below code to perform predictions using your model.\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below.\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Activate the (pre-)trained model\n","\n","\n","# Provide the code for performing predictions and saving them\n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw"},"source":["\n","#**Thank you for using YOUR NETWORK!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Template_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611141557911},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"gfIn-nNNhdzh"},"source":[" This is a template for a ZeroCostDL4Mic notebook and needs to be filled with appropriate model code and information.\r\n","\r\n"," Thank you for contributing to the ZeroCostDL4Mic Project. Please use this notebook as a template for your implementation. When your notebook is completed, please upload it to your github page and send us a link so we can reference your work.\r\n","\r\n"," If possible, remember to provide separate training and test datasets (for quality control) containing source and target images with your finished notebooks. This is very useful so that ZeroCostDL4Mic users can test your notebook. 
"]},{"cell_type":"markdown","metadata":{"id":"Av1qDcfthk1a"},"source":["# **Name of the Network**\r\n","\r\n","---\r\n","\r\n"," Description of the network and link to publication with author reference. [author et al, etc.](URL).\r\n","\r\n","---\r\n","\r\n","*Disclaimer*:\r\n","\r\n","This notebook is inspired from the *Zero-Cost Deep-Learning to Enhance Microscopy* project (ZeroCostDL4Mic) (https://github.com/HenriquesLab/DeepLearning_Collab/wiki) and was created by **Your name**\r\n","\r\n","This notebook is based on the following paper: \r\n","\r\n","**Original Title of the paper**, Journal, volume, pages, year and complete author list, [link to paper](URL)\r\n","\r\n","And source code found in: *provide github link or equivalent if applicable*\r\n","\r\n","Provide information on dataset availability and link for download if applicable.\r\n","\r\n","\r\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"TKktwSaWhq9e"},"source":["# **How to use this notebook?**\r\n","\r\n","---\r\n","\r\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\r\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\r\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\r\n","\r\n","\r\n","---\r\n","###**Structure of a notebook**\r\n","\r\n","The notebook contains two types of cell: \r\n","\r\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\r\n","\r\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\r\n","\r\n","---\r\n","###**Table of contents, Code snippets** and **Files**\r\n","\r\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\r\n","\r\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\r\n","\r\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\r\n","\r\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \r\n","\r\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\r\n","\r\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\r\n","\r\n","---\r\n","###**Making changes to the notebook**\r\n","\r\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\r\n","\r\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\r\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"_v_Jl2QZhvLh"},"source":["#**0. Before getting started**\r\n","---\r\n"," Give information on the required structure and datatype of the training dataset.\r\n","\r\n"," Provide information on the quality control dataset, such as:\r\n","\r\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\r\n","\r\n"," **Additionally, the corresponding input and output files need to have the same name**.\r\n","\r\n"," Please note that you currently can **only use .tif files!**\r\n","\r\n","\r\n","Here's a common data structure that can work:\r\n","* Experiment A\r\n"," - **Training dataset**\r\n"," - Low SNR images (Training_source)\r\n"," - img_1.tif, img_2.tif, ...\r\n"," - High SNR images (Training_target)\r\n"," - img_1.tif, img_2.tif, ...\r\n"," - **Quality control dataset**\r\n"," - Low SNR images\r\n"," - img_1.tif, img_2.tif\r\n"," - High SNR images\r\n"," - img_1.tif, img_2.tif\r\n"," - **Data to be predicted**\r\n"," - **Results**\r\n","\r\n","---\r\n","**Important note**\r\n","\r\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\r\n","\r\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\r\n","\r\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\r\n","---"]},{"cell_type":"markdown","metadata":{"id":"EPOJkyFYiA15"},"source":["# **1. Initialise the Colab session**\r\n","---\r\n","\r\n","\r\n","\r\n","\r\n"]},{"cell_type":"markdown","metadata":{"id":"8dvLrwF_iEXS"},"source":["\r\n","## **1.1. Check for GPU access**\r\n","---\r\n","\r\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\r\n","\r\n","Go to **Runtime -> Change the Runtime type**\r\n","\r\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\r\n","\r\n","**Accelerator: GPU** *(Graphics processing unit)*\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"8o_-wbDOiIHF"},"source":["#@markdown ##Run this cell to check if you have GPU access\r\n","%tensorflow_version 1.x\r\n","\r\n","import tensorflow as tf\r\n","if tf.test.gpu_device_name()=='':\r\n"," print('You do not have GPU access.') \r\n"," print('Did you change your runtime ?') \r\n"," print('If the runtime settings are correct then Google did not allocate a GPU to your session')\r\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\r\n","\r\n","else:\r\n"," print('You have GPU access')\r\n","\r\n","from tensorflow.python.client import device_lib \r\n","device_lib.list_local_devices()\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kEyJvvxSiN6L"},"source":["## **1.2. Mount your Google Drive**\r\n","---\r\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\r\n","\r\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \r\n","\r\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"WWVR1U5tiM9h"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\r\n","\r\n","#@markdown * Click on the URL. \r\n","\r\n","#@markdown * Sign in your Google Account. \r\n","\r\n","#@markdown * Copy the authorization code. \r\n","\r\n","#@markdown * Enter the authorization code. \r\n","\r\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \r\n","\r\n","#mounts user's Google Drive to Google Colab.\r\n","\r\n","from google.colab import drive\r\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NvJvtQQgiVDF"},"source":["# **2. Install Name of the network and dependencies**\r\n","---\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"XMi71QrxiZbS"},"source":["#@markdown ##Install Network and dependencies\r\n","\r\n","#Libraries contains information of certain topics. \r\n","\r\n","#Put the imported code and libraries here\r\n","\r\n","Notebook_version = ['1.12'] #Contact the ZeroCostDL4Mic team to find out about the version number\r\n","\r\n","!pip install fpdf\r\n","\r\n","# Below are templates for the function definitions for the export\r\n","# of pdf summaries for training and qc. 
You will need to adjust these functions\r\n","# with the variables and other parameters as necessary to make them\r\n","# work for your project\r\n","from datetime import datetime\r\n","\r\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\r\n"," # save FPDF() class into a \r\n"," # variable pdf \r\n"," #from datetime import datetime\r\n","\r\n"," class MyFPDF(FPDF, HTMLMixin):\r\n"," pass\r\n","\r\n"," pdf = MyFPDF()\r\n"," pdf.add_page()\r\n"," pdf.set_right_margin(-1)\r\n"," pdf.set_font(\"Arial\", size = 11, style='B') \r\n","\r\n"," Network = \"Your network's name\"\r\n"," day = datetime.now()\r\n"," datetime_str = str(day)[0:10]\r\n","\r\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\r\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \r\n","\r\n"," # add another cell \r\n"," if trained:\r\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\r\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\r\n"," pdf.ln(1)\r\n","\r\n"," Header_2 = 'Information for your materials and methods:'\r\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\r\n","\r\n"," all_packages = ''\r\n"," for requirement in freeze(local_only=True):\r\n"," all_packages = all_packages+requirement+', '\r\n"," #print(all_packages)\r\n","\r\n"," #Main Packages\r\n"," main_packages = ''\r\n"," version_numbers = []\r\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\r\n"," find_name=all_packages.find(name)\r\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\r\n"," #Version numbers only here:\r\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\r\n","\r\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\r\n"," cuda_version = cuda_version.stdout.decode('utf-8')\r\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\r\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\r\n"," gpu_name = gpu_name.stdout.decode('utf-8')\r\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\r\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\r\n"," #print(gpu_name)\r\n","\r\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\r\n"," dataset_size = len(os.listdir(Training_source))\r\n","\r\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\r\n","\r\n"," if pretrained_model:\r\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\r\n","\r\n"," pdf.set_font('')\r\n"," pdf.set_font_size(10.)\r\n"," pdf.multi_cell(190, 5, txt = text, align='L')\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 10, style = 'B')\r\n"," pdf.ln(1)\r\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\r\n"," pdf.set_font('')\r\n"," if augmentation:\r\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\r\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\r\n"," aug_text = aug_text+'\\n- rotation'\r\n"," if flip_left_right != 0 or flip_top_bottom != 0:\r\n"," aug_text = aug_text+'\\n- flipping'\r\n"," if random_zoom_magnification != 0:\r\n"," aug_text = aug_text+'\\n- random zoom magnification'\r\n"," if random_distortion != 0:\r\n"," aug_text = aug_text+'\\n- random distortion'\r\n"," if image_shear != 0:\r\n"," aug_text = aug_text+'\\n- image shearing'\r\n"," if skew_image != 0:\r\n"," aug_text = aug_text+'\\n- image skewing'\r\n"," else:\r\n"," aug_text = 'No augmentation was used for training.'\r\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\r\n"," pdf.set_font('Arial', size = 11, style = 'B')\r\n"," pdf.ln(1)\r\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font_size(10.)\r\n"," if Use_Default_Advanced_Parameters:\r\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\r\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\r\n"," pdf.ln(1)\r\n"," html = \"\"\" \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n","
ParameterValue
number_of_epochs{0}
patch_size{1}
number_of_patches{2}
batch_size{3}
number_of_steps{4}
percentage_validation{5}
initial_learning_rate{6}
\r\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\r\n"," pdf.write_html(html)\r\n","\r\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\r\n"," pdf.set_font(\"Arial\", size = 11, style='B')\r\n"," pdf.ln(1)\r\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 10, style = 'B')\r\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\r\n"," pdf.set_font('')\r\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 10, style = 'B')\r\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\r\n"," pdf.set_font('')\r\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\r\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\r\n"," pdf.ln(1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 10, style = 'B')\r\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\r\n"," pdf.set_font('')\r\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\r\n"," pdf.ln(1)\r\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\r\n"," pdf.ln(1)\r\n"," exp_size = io.imread(\"/content/NetworkNameExampleData.png\").shape\r\n"," pdf.image(\"/content/NetworkNameExampleData.png\", x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\r\n"," pdf.ln(1)\r\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\r\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\r\n"," ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\r\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\r\n"," if augmentation:\r\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\r\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\r\n"," pdf.ln(3)\r\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\r\n"," pdf.set_font('Arial', size = 11, style='B')\r\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\r\n","\r\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\r\n","\r\n","\r\n","#Make a pdf summary of the QC results\r\n","\r\n","def qc_pdf_export():\r\n"," class MyFPDF(FPDF, HTMLMixin):\r\n"," pass\r\n","\r\n"," pdf = MyFPDF()\r\n"," pdf.add_page()\r\n"," pdf.set_right_margin(-1)\r\n"," pdf.set_font(\"Arial\", size = 11, style='B') \r\n","\r\n"," Network = \"Your network's name\"\r\n"," #model_name = os.path.basename(full_QC_model_path)\r\n"," day = datetime.now()\r\n"," datetime_str = str(day)[0:10]\r\n","\r\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\r\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \r\n","\r\n"," all_packages = ''\r\n"," for requirement in freeze(local_only=True):\r\n"," all_packages = all_packages+requirement+', '\r\n","\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 11, style = 'B')\r\n"," pdf.ln(2)\r\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\r\n"," pdf.ln(1)\r\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\r\n"," if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\r\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\r\n"," else:\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size=10)\r\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\r\n"," pdf.ln(2)\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 10, style = 'B')\r\n"," pdf.ln(3)\r\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\r\n"," pdf.ln(1)\r\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\r\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\r\n"," pdf.ln(1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font('Arial', size = 11, style = 'B')\r\n"," pdf.ln(1)\r\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font_size(10.)\r\n","\r\n"," pdf.ln(1)\r\n"," html = \"\"\"\r\n"," \r\n"," \r\n"," \"\"\"\r\n"," with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\r\n"," metrics = csv.reader(csvfile)\r\n"," header = next(metrics)\r\n"," image = header[0]\r\n"," mSSIM_PvsGT = header[1]\r\n"," mSSIM_SvsGT = header[2]\r\n"," NRMSE_PvsGT = header[3]\r\n"," NRMSE_SvsGT = header[4]\r\n"," PSNR_PvsGT = header[5]\r\n"," PSNR_SvsGT = header[6]\r\n"," header = \"\"\"\r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\r\n"," html = html+header\r\n"," for row in metrics:\r\n"," image = row[0]\r\n"," mSSIM_PvsGT = row[1]\r\n"," mSSIM_SvsGT = row[2]\r\n"," NRMSE_PvsGT = 
row[3]\r\n"," NRMSE_SvsGT = row[4]\r\n"," PSNR_PvsGT = row[5]\r\n"," PSNR_SvsGT = row[6]\r\n"," cells = \"\"\"\r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\r\n"," html = html+cells\r\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\r\n","\r\n"," pdf.write_html(html)\r\n","\r\n"," pdf.ln(1)\r\n"," pdf.set_font('')\r\n"," pdf.set_font_size(10.)\r\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\r\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\r\n"," ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\r\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\r\n","\r\n"," pdf.ln(3)\r\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\r\n","\r\n"," pdf.set_font('Arial', size = 11, style='B')\r\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\r\n","\r\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\r\n","\r\n","print(\"Depencies installed and imported.\")\r\n","\r\n","# Exporting requirements.txt for local run \r\n","# -- the developers should leave this below all the other installations\r\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jKaeBnSuifZn"},"source":["# **3. Select your paths and parameters**\r\n","\r\n","---\r\n","\r\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\r\n"]},{"cell_type":"markdown","metadata":{"id":"StTGluw2iidc"},"source":["## **3.1. Setting the main training parameters**\r\n","---\r\n",""]},{"cell_type":"markdown","metadata":{"id":"GyRjBdClimfK"},"source":[" **Paths for training, predictions and results**\r\n","\r\n"," Fill the parameters here as needed and update the code. Note that the sections containing `Training_source`, `Training target`, `model_name` and `model_path` should appear in your notebook.\r\n","\r\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\r\n","\r\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\r\n","\r\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\r\n","\r\n","**Training parameters**\r\n","\r\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. **Default value:**\r\n","\r\n","**`other_parameters`:**Give other parameters or default values **Default value:**\r\n","\r\n","**If additional parameter above affects the training of the notebook give a brief explanation and how problems can be mitigated** \r\n","\r\n","\r\n","**Advanced parameters - experienced users only**\r\n","\r\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\r\n","\r\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. 
Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\r\n","\r\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** "]},{"cell_type":"code","metadata":{"cellView":"form","id":"i1sKnXrDieiR"},"source":["class bcolors:\r\n"," WARNING = '\\033[31m'\r\n","\r\n","#@markdown ###Path to training images:\r\n","\r\n","Training_source = \"\" #@param {type:\"string\"}\r\n","\r\n","# Ground truth images\r\n","Training_target = \"\" #@param {type:\"string\"}\r\n","\r\n","# model name and path\r\n","#@markdown ###Name of the model and path to model folder:\r\n","model_name = \"\" #@param {type:\"string\"}\r\n","model_path = \"\" #@param {type:\"string\"}\r\n","\r\n","\r\n","# other parameters for training.\r\n","#@markdown ###Training Parameters\r\n","#@markdown Number of epochs:\r\n","\r\n","number_of_epochs = 50#@param {type:\"number\"}\r\n","\r\n","#@markdown Other parameters, add as necessary\r\n","other_parameters = 80#@param {type:\"number\"} # in pixels\r\n","\r\n","\r\n","#@markdown ###Advanced Parameters\r\n","\r\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\r\n","#@markdown ###If not, please input:\r\n","\r\n","number_of_steps = 400#@param {type:\"number\"}\r\n","batch_size = 16#@param {type:\"number\"}\r\n","percentage_validation = 10 #@param {type:\"number\"}\r\n","\r\n","\r\n","if (Use_Default_Advanced_Parameters): \r\n"," print(\"Default advanced parameters enabled\")\r\n"," batch_size = 16\r\n"," percentage_validation = 10\r\n","\r\n","#Here we define the percentage to use for validation\r\n","percentage = percentage_validation/100\r\n","\r\n","\r\n","#here we check that no model with the same name already exist, if so delete\r\n","if os.path.exists(model_path+'/'+model_name):\r\n"," shutil.rmtree(model_path+'/'+model_name)\r\n","\r\n","\r\n","# The shape of the images.\r\n","x = imread(InputFile)\r\n","y = imread(OutputFile)\r\n","\r\n","print('Loaded Input images (number, width, length) =', x.shape)\r\n","print('Loaded Output images (number, width, length) =', y.shape)\r\n","print(\"Parameters initiated.\")\r\n","\r\n","# This will display a randomly chosen dataset input and output\r\n","random_choice = random.choice(os.listdir(Training_source))\r\n","x = imread(Training_source+\"/\"+random_choice)\r\n","\r\n","\r\n","# Here we check that the input images contains the expected dimensions\r\n","if len(x.shape) == 2:\r\n"," print(\"Image dimensions (y,x)\",x.shape)\r\n","\r\n","if not len(x.shape) == 2:\r\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\r\n","\r\n","\r\n","#Find image XY dimension\r\n","Image_Y = x.shape[0]\r\n","Image_X = x.shape[1]\r\n","\r\n","#Hyperparameters failsafes\r\n","\r\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \r\n","\r\n","if patch_size > min(Image_Y, Image_X):\r\n"," patch_size = min(Image_Y, Image_X)\r\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\r\n","\r\n","# Here we check that patch_size is divisible by 8\r\n","if not patch_size % 8 == 0:\r\n"," patch_size = ((int(patch_size / 8)-1) * 8)\r\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\r\n","\r\n","\r\n","\r\n","os.chdir(Training_target)\r\n","y = imread(Training_target+\"/\"+random_choice)\r\n","\r\n","\r\n","f=plt.figure(figsize=(16,8))\r\n","plt.subplot(1,2,1)\r\n","plt.imshow(x, interpolation='nearest')\r\n","plt.title('Training source')\r\n","plt.axis('off');\r\n","\r\n","plt.subplot(1,2,2)\r\n","plt.imshow(y, interpolation='nearest')\r\n","plt.title('Training target')\r\n","plt.axis('off');\r\n","#We save the example data here to use it in the pdf export of the training\r\n","plt.savefig('/content/NetworkNameExampleData.png', bbox_inches='tight', pad_inches=0)\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VLYZQA6GitQL"},"source":["## **3.2. Data augmentation**\r\n","---\r\n",""]},{"cell_type":"markdown","metadata":{"id":"M4GfK6-1iwbf"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\r\n","\r\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. 
This only works if the images are square in XY.\r\n","\r\n","Add any other information which is necessary to run augmentation with your notebook/data."]},{"cell_type":"code","metadata":{"cellView":"form","id":"EkBGtraZi3Ob"},"source":["#@markdown ###Add any further useful augmentations\r\n","Use_Data_augmentation = False #@param{type:\"boolean\"}\r\n","\r\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\r\n","\r\n","#@markdown **Rotate each image 3 times by 90 degrees.**\r\n","Rotation = True #@param{type:\"boolean\"}\r\n","\r\n","#@markdown **Flip each image once around the x axis of the stack.**\r\n","Flip = True #@param{type:\"boolean\"}\r\n","\r\n","\r\n","#@markdown **Would you like to save your augmented images?**\r\n","\r\n","Save_augmented_images = False #@param {type:\"boolean\"}\r\n","\r\n","Saving_path = \"\" #@param {type:\"string\"}\r\n","\r\n","\r\n","if not Save_augmented_images:\r\n"," Saving_path= \"/content\"\r\n","\r\n","\r\n","def rotation_aug(Source_path, Target_path, flip=False):\r\n"," Source_images = os.listdir(Source_path)\r\n"," Target_images = os.listdir(Target_path)\r\n"," \r\n"," for image in Source_images:\r\n"," source_img = io.imread(os.path.join(Source_path,image))\r\n"," target_img = io.imread(os.path.join(Target_path,image))\r\n"," \r\n"," # Source Rotation\r\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\r\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\r\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\r\n","\r\n"," # Target Rotation\r\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\r\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\r\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\r\n","\r\n"," # Add a flip to the rotation\r\n"," \r\n"," if flip == True:\r\n"," source_img_lr = np.fliplr(source_img)\r\n"," source_img_90_lr = np.fliplr(source_img_90)\r\n"," source_img_180_lr = np.fliplr(source_img_180)\r\n"," source_img_270_lr = np.fliplr(source_img_270)\r\n","\r\n"," target_img_lr = np.fliplr(target_img)\r\n"," target_img_90_lr = np.fliplr(target_img_90)\r\n"," target_img_180_lr = np.fliplr(target_img_180)\r\n"," target_img_270_lr = np.fliplr(target_img_270)\r\n","\r\n"," #source_img_90_ud = np.flipud(source_img_90)\r\n"," \r\n"," # Save the augmented files\r\n"," # Source images\r\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\r\n"," # Target images\r\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\r\n","\r\n"," if flip == True:\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\r\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\r\n","\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\r\n","\r\n","def flip(Source_path, Target_path):\r\n"," Source_images = os.listdir(Source_path)\r\n"," Target_images = os.listdir(Target_path) \r\n","\r\n"," for image in Source_images:\r\n"," source_img = io.imread(os.path.join(Source_path,image))\r\n"," target_img = io.imread(os.path.join(Target_path,image))\r\n"," \r\n"," source_img_lr = np.fliplr(source_img)\r\n"," target_img_lr = np.fliplr(target_img)\r\n","\r\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\r\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\r\n","\r\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\r\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\r\n","\r\n","\r\n","if Use_Data_augmentation:\r\n","\r\n"," if os.path.exists(Saving_path+'/augmented_source'):\r\n"," shutil.rmtree(Saving_path+'/augmented_source')\r\n"," os.mkdir(Saving_path+'/augmented_source')\r\n","\r\n"," if os.path.exists(Saving_path+'/augmented_target'):\r\n"," shutil.rmtree(Saving_path+'/augmented_target') \r\n"," os.mkdir(Saving_path+'/augmented_target')\r\n","\r\n"," print(\"Data augmentation enabled\")\r\n"," print(\"Data augmentation in progress....\")\r\n","\r\n"," if Rotation == True:\r\n"," rotation_aug(Training_source,Training_target,flip=Flip)\r\n"," \r\n"," elif Rotation == False and Flip == True:\r\n"," flip(Training_source,Training_target)\r\n"," print(\"Done\")\r\n","\r\n","\r\n","if not Use_Data_augmentation:\r\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-Y-47ZmFiyG_"},"source":["## **3.3. Using weights from a pre-trained model as initial weights**\r\n","---\r\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a model of Your Network**. \r\n","\r\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\r\n","\r\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
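\r\n","\r\n","A minimal, hypothetical sketch of what the weight-loading placeholder in the code cell below could look like (it assumes a Keras-style model object named `model`; `Use_pretrained_model` and `h5_file_path` are defined in that cell and should be adapted to your own network's API):\r\n","\r\n","```python\r\n","# hypothetical example - replace 'model' with your own network object\r\n","if Use_pretrained_model:\r\n","  model.load_weights(h5_file_path)\r\n","```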
"]},{"cell_type":"code","metadata":{"cellView":"form","id":"jSb9luhrjHe-"},"source":["# @markdown ##Loading weights from a pre-trained network\r\n","\r\n","Use_pretrained_model = False #@param {type:\"boolean\"}\r\n","\r\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\r\n","\r\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\r\n","\r\n","\r\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\r\n","pretrained_model_path = \"\" #@param {type:\"string\"}\r\n","\r\n","# --------------------- Check if we load a previously trained model ------------------------\r\n","if Use_pretrained_model:\r\n","\r\n","# --------------------- Load the model from the choosen path ------------------------\r\n"," if pretrained_model_choice == \"Model_from_file\":\r\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\r\n","\r\n","\r\n","# --------------------- Download the a model provided in the XXX ------------------------\r\n","\r\n"," if pretrained_model_choice == \"Model_name\":\r\n"," pretrained_model_name = \"Model_name\"\r\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\r\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\r\n"," if os.path.exists(pretrained_model_path):\r\n"," shutil.rmtree(pretrained_model_path)\r\n"," os.makedirs(pretrained_model_path)\r\n"," wget.download(\"\", pretrained_model_path)\r\n"," wget.download(\"\", pretrained_model_path)\r\n"," wget.download(\"\", pretrained_model_path) \r\n"," wget.download(\"\", pretrained_model_path)\r\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\r\n","\r\n","# --------------------- Add additional pre-trained models here ------------------------\r\n","\r\n","# --------------------- Check the model exist ------------------------\r\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \r\n"," if not os.path.exists(h5_file_path):\r\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\r\n"," Use_pretrained_model = False\r\n","\r\n"," \r\n","# If the model path contains a pretrain model, we load the training rate, \r\n"," if os.path.exists(h5_file_path):\r\n","#Here we check if the learning rate can be loaded from the quality control folder\r\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\r\n","\r\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\r\n"," csvRead = pd.read_csv(csvfile, sep=',')\r\n"," #print(csvRead)\r\n"," \r\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\r\n"," print(\"pretrained network learning rate found\")\r\n"," #find the last learning rate\r\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\r\n"," #Find the learning rate corresponding to the lowest validation loss\r\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\r\n"," #print(min_val_loss)\r\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\r\n","\r\n"," if Weights_choice == \"last\":\r\n"," print('Last learning rate: '+str(lastLearningRate))\r\n","\r\n"," if Weights_choice == \"best\":\r\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\r\n","\r\n"," if not \"learning rate\" in 
csvRead.columns: #if the column does not exist, then initial learning rate is used instead\r\n"," bestLearningRate = initial_learning_rate\r\n"," lastLearningRate = initial_learning_rate\r\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\r\n","\r\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\r\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\r\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\r\n"," bestLearningRate = initial_learning_rate\r\n"," lastLearningRate = initial_learning_rate\r\n","\r\n","\r\n","# Display info about the pretrained model to be loaded (or not)\r\n","if Use_pretrained_model:\r\n"," print('Weights found in:')\r\n"," print(h5_file_path)\r\n"," print('will be loaded prior to training.')\r\n","\r\n","else:\r\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\r\n","\r\n","\r\n","#@markdown ### You will need to add or replace the code that loads any previously trained weights to the notebook here."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sjTtP2OmjMqM"},"source":["# **4. Train the network**\r\n","---"]},{"cell_type":"markdown","metadata":{"id":"yQ9NgI6XjQIk"},"source":["## **4.1. Train the network**\r\n","---\r\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\r\n","\r\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\r\n","\r\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"SVUd0Lr0jUjy"},"source":["import time\r\n","import csv\r\n","\r\n","# Export the training parameters as pdf (before training, in case training fails) \r\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\r\n","\r\n","start = time.time()\r\n","\r\n","#@markdown ##Start training\r\n","\r\n","# Start Training\r\n","\r\n","#Insert the code necessary to initiate training of your model\r\n","\r\n","#Note that the notebook should load weights either from the model that is \r\n","#trained from scratch or if the pretrained weights are used (3.3.)\r\n","\r\n","# Displaying the time elapsed for training\r\n","dt = time.time() - start\r\n","mins, sec = divmod(dt, 60) \r\n","hour, mins = divmod(mins, 60) \r\n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\r\n","\r\n","# Export the training parameters as pdf (after training)\r\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1Tm3aimXjZ1B"},"source":["# **5. 
Evaluate your model**\r\n","---\r\n","\r\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \r\n","\r\n","**We highly recommend performing quality control on all newly trained models.**\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"QAXu1FR0jYZC"},"source":["# model name and path\r\n","#@markdown ###Do you want to assess the model you just trained ?\r\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\r\n","\r\n","#@markdown ###If not, please provide the name of the model and path to model folder:\r\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below. \r\n","\r\n","QC_model_folder = \"\" #@param {type:\"string\"}\r\n","\r\n","#Here we define the loaded model name and path\r\n","QC_model_name = os.path.basename(QC_model_folder)\r\n","QC_model_path = os.path.dirname(QC_model_folder)\r\n","\r\n","if (Use_the_current_trained_model): \r\n"," QC_model_name = model_name\r\n"," QC_model_path = model_path\r\n","\r\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\r\n","if os.path.exists(full_QC_model_path):\r\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\r\n","else:\r\n"," W = '\\033[0m' # white (normal)\r\n"," R = '\\033[31m' # red\r\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\r\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ULMuc37njkXM"},"source":["## **5.1. Inspection of the loss function**\r\n","---\r\n","\r\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\r\n","\r\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\r\n","\r\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\r\n","\r\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\r\n","\r\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"1VCvEofKjjHN"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\r\n","import csv\r\n","from matplotlib import pyplot as plt\r\n","\r\n","lossDataFromCSV = []\r\n","vallossDataFromCSV = []\r\n","\r\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\r\n"," csvRead = csv.reader(csvfile, delimiter=',')\r\n"," next(csvRead)\r\n"," for row in csvRead:\r\n"," lossDataFromCSV.append(float(row[0]))\r\n"," vallossDataFromCSV.append(float(row[1]))\r\n","\r\n","epochNumber = range(len(lossDataFromCSV))\r\n","plt.figure(figsize=(15,10))\r\n","\r\n","plt.subplot(2,1,1)\r\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\r\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\r\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\r\n","plt.ylabel('Loss')\r\n","plt.xlabel('Epoch number')\r\n","plt.legend()\r\n","\r\n","plt.subplot(2,1,2)\r\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\r\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\r\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\r\n","plt.ylabel('Loss')\r\n","plt.xlabel('Epoch number')\r\n","plt.legend()\r\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\r\n","plt.show()\r\n","\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"smiWe2wcjwTc"},"source":["## **5.2. Error mapping and quality metrics estimation**\r\n","---\r\n","\r\n"," Update the code below to perform predictions on your quality control dataset. Use the metrics that are the most meaningful to assess the quality of the prediction.\r\n","\r\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\r\n","\r\n","**1. The SSIM (structural similarity) map** \r\n","\r\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \r\n","\r\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\r\n","\r\n","**The output below shows the SSIM maps with the mSSIM**\r\n","\r\n","**2. The RSE (Root Squared Error) map** \r\n","\r\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\r\n","\r\n","\r\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
Good agreement yields low NRMSE scores.\r\n","\r\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\r\n","\r\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\r\n","\r\n","\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"Z179Zxgtj0PP"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\r\n","\r\n","from skimage.metrics import structural_similarity\r\n","from skimage.metrics import peak_signal_noise_ratio as psnr\r\n","\r\n","Source_QC_folder = \"\" #@param{type:\"string\"}\r\n","Target_QC_folder = \"\" #@param{type:\"string\"}\r\n","\r\n","# Create a quality control/Prediction Folder\r\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\r\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\r\n","\r\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\r\n","\r\n","# Insert code to activate the pretrained model if necessary. \r\n","\r\n","# List Tif images in Source_QC_folder\r\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\r\n","Z = sorted(glob(Source_QC_folder_tif))\r\n","Z = list(map(imread,Z))\r\n","print('Number of test dataset found in the folder: '+str(len(Z)))\r\n","\r\n","\r\n","# Insert code to perform predictions on all datasets in the Source_QC folder\r\n","\r\n","\r\n","def ssim(img1, img2):\r\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\r\n","\r\n","\r\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\r\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\r\n"," \"\"\"Percentile-based image normalization.\"\"\"\r\n","\r\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\r\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\r\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\r\n","\r\n","\r\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\r\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\r\n"," if dtype is not None:\r\n"," x = x.astype(dtype,copy=False)\r\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\r\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\r\n"," eps = dtype(eps)\r\n","\r\n"," try:\r\n"," import numexpr\r\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\r\n"," except ImportError:\r\n"," x = (x - mi) / ( ma - mi + eps )\r\n","\r\n"," if clip:\r\n"," x = np.clip(x,0,1)\r\n","\r\n"," return x\r\n","\r\n","def norm_minmse(gt, x, normalize_gt=True):\r\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\r\n","\r\n"," \"\"\"\r\n"," normalizes and affinely scales an image pair such that the MSE is minimized \r\n"," \r\n"," Parameters\r\n"," ----------\r\n"," gt: ndarray\r\n"," the ground truth image \r\n"," x: ndarray\r\n"," the image that will be affinely scaled \r\n"," normalize_gt: bool\r\n"," set to True of gt image should be normalized (default)\r\n"," Returns\r\n"," -------\r\n"," gt_scaled, x_scaled \r\n"," \"\"\"\r\n"," if normalize_gt:\r\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\r\n"," x = x.astype(np.float32, copy=False) - 
np.mean(x)\r\n"," #x = x - np.mean(x)\r\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\r\n"," #gt = gt - np.mean(gt)\r\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\r\n"," return gt, scale * x\r\n","\r\n","# Open and create the csv file that will contain all the QC metrics\r\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\r\n"," writer = csv.writer(file)\r\n","\r\n"," # Write the header in the csv file\r\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \r\n","\r\n"," # Let's loop through the provided dataset in the QC folders\r\n","\r\n","\r\n"," for i in os.listdir(Source_QC_folder):\r\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\r\n"," print('Running QC on: '+i)\r\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\r\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\r\n","\r\n"," # -------------------------------- Source test data --------------------------------\r\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\r\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\r\n","\r\n"," # -------------------------------- Prediction --------------------------------\r\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\r\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \r\n","\r\n","\r\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\r\n","\r\n"," # Calculate the SSIM maps\r\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\r\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\r\n","\r\n"," #Save ssim_maps\r\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\r\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\r\n"," \r\n"," # Calculate the Root Squared Error (RSE) maps\r\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\r\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\r\n","\r\n"," # Save SE maps\r\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\r\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\r\n","\r\n","\r\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\r\n","\r\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\r\n"," NRMSE_GTvsPrediction = 
np.sqrt(np.mean(img_RSE_GTvsPrediction))\r\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\r\n"," \r\n"," # We can also measure the peak signal to noise ratio between the images\r\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\r\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\r\n","\r\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\r\n","\r\n","\r\n","# All data is now processed saved\r\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\r\n","\r\n","plt.figure(figsize=(15,15))\r\n","# Currently only displays the last computed set, from memory\r\n","# Target (Ground-truth)\r\n","plt.subplot(3,3,1)\r\n","plt.axis('off')\r\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\r\n","plt.imshow(img_GT)\r\n","plt.title('Target',fontsize=15)\r\n","\r\n","# Source\r\n","plt.subplot(3,3,2)\r\n","plt.axis('off')\r\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\r\n","plt.imshow(img_Source)\r\n","plt.title('Source',fontsize=15)\r\n","\r\n","#Prediction\r\n","plt.subplot(3,3,3)\r\n","plt.axis('off')\r\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\r\n","plt.imshow(img_Prediction)\r\n","plt.title('Prediction',fontsize=15)\r\n","\r\n","#Setting up colours\r\n","cmap = plt.cm.CMRmap\r\n","\r\n","#SSIM between GT and Source\r\n","plt.subplot(3,3,5)\r\n","#plt.axis('off')\r\n","plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\r\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\r\n","plt.title('Target vs. Source',fontsize=15)\r\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\r\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#SSIM between GT and Prediction\r\n","plt.subplot(3,3,6)\r\n","#plt.axis('off')\r\n","plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\r\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\r\n","plt.title('Target vs. 
Prediction',fontsize=15)\r\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\r\n","\r\n","#Root Squared Error between GT and Source\r\n","plt.subplot(3,3,8)\r\n","#plt.axis('off')\r\n","plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\r\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\r\n","plt.title('Target vs. Source',fontsize=15)\r\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\r\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\r\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#Root Squared Error between GT and Prediction\r\n","plt.subplot(3,3,9)\r\n","#plt.axis('off')\r\n","plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\r\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\r\n","plt.title('Target vs. Prediction',fontsize=15)\r\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\r\n","\r\n","\r\n","#Make a pdf summary of the QC results\r\n","\r\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fB8QNLekkCyZ"},"source":["# **6. Using the trained model**\r\n","\r\n","---\r\n","\r\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"B2DrAOANkIWu"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\r\n","---\r\n","Fill the below code to perform predictions using your model.\r\n","\r\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
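As a starting point, here is a minimal sketch of what such a prediction loop could look like. It is illustrative only: the 3D-RCAN prediction call is left as a placeholder in this template, so the Keras-style `model.predict()` used below is an assumption, and `Data_folder` / `Result_folder` are the paths you enter in the cell below.

```python
# Illustrative sketch only -- the 3D-RCAN prediction API is not defined in this
# template, so the Keras-style `model.predict()` call below is an assumption.
import os
from glob import glob
import numpy as np
from skimage import io

def predict_folder(model, Data_folder, Result_folder):
    # Process every TIFF stack found in Data_folder and save the restored result
    for path in sorted(glob(os.path.join(Data_folder, '*.tif'))):
        stack = io.imread(path).astype(np.float32)      # (z, y, x) image stack
        batch = stack[np.newaxis, ..., np.newaxis]      # add batch and channel axes
        restored = np.squeeze(model.predict(batch, batch_size=1))
        io.imsave(os.path.join(Result_folder, os.path.basename(path)),
                  restored.astype(np.float32))          # ImageJ-compatible TIFF
```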
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\r\n","\r\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\r\n","\r\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"mELG8z-ykCKV"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\r\n","\r\n","Data_folder = \"\" #@param {type:\"string\"}\r\n","Result_folder = \"\" #@param {type:\"string\"}\r\n","\r\n","# model name and path\r\n","#@markdown ###Do you want to use the current trained model?\r\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\r\n","\r\n","#@markdown ###If not, provide the name of the model and path to model folder:\r\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below.\r\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\r\n","\r\n","#Here we find the loaded model name and parent path\r\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\r\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\r\n","\r\n","if (Use_the_current_trained_model): \r\n"," print(\"Using current trained network\")\r\n"," Prediction_model_name = model_name\r\n"," Prediction_model_path = model_path\r\n","\r\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\r\n","if os.path.exists(full_Prediction_model_path):\r\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\r\n","else:\r\n"," W = '\\033[0m' # white (normal)\r\n"," R = '\\033[31m' # red\r\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\r\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\r\n","\r\n","\r\n","# Activate the (pre-)trained model\r\n","\r\n","\r\n","# Provide the code for performing predictions and saving them\r\n","\r\n","\r\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JnSk14AJkRtJ"},"source":["## **6.2. Inspect the predicted output**\r\n","---\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"hlkZUhj4kQ2Z"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\r\n","\r\n","# This will display a randomly chosen dataset input and predicted output\r\n","random_choice = random.choice(os.listdir(Data_folder))\r\n","x = imread(Data_folder+\"/\"+random_choice)\r\n","\r\n","os.chdir(Result_folder)\r\n","y = imread(Result_folder+\"/\"+random_choice)\r\n","\r\n","plt.figure(figsize=(16,8))\r\n","\r\n","plt.subplot(1,2,1)\r\n","plt.axis('off')\r\n","plt.imshow(x, interpolation='nearest')\r\n","plt.title('Input')\r\n","\r\n","plt.subplot(1,2,2)\r\n","plt.axis('off')\r\n","plt.imshow(y, interpolation='nearest')\r\n","plt.title('Predicted output');\r\n","\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gP7WDm6bkYkb"},"source":["## **6.3. 
Download your predictions**\r\n","---\r\n","\r\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"JbOn8U-VkerU"},"source":["\r\n","#**Thank you for using YOUR NETWORK!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb index 87778df0..39c4a627 100644 --- a/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1vOE2M7clX0zm-5wOquL29VP7-nqeHPbJ","timestamp":1602259779758},{"file_id":"1VcTsLOL28ntbr23gYrhY3upxkztZeUvn","timestamp":1591024690909},{"file_id":"19jT_GoHGN-UTM1aEgkgrOjB8pcFz5AW4","timestamp":1591017297795},{"file_id":"1UkoWB27ZWh5j_qivSZIOeOJP1h2EqrVz","timestamp":1589363183397},{"file_id":"1ofNqOc7lz-m6NL4B-m4BIheaU5N0GMln","timestamp":1588873191434},{"file_id":"1rJnsgIKyL6vuneydIfjCKMtMhV3XlQ6o","timestamp":1588583580765},{"file_id":"1RUYrp8beEgDKL1kOWw5LgR1QQb4yHQtG","timestamp":1587061416704},{"file_id":"1FVax0eY3-m8DbJHx0B8Dnep-uGlp30Zt","timestamp":1586601038120},{"file_id":"1TTqmCf2mFQ_PNIZEXX9sRAhoixjYP_AB","timestamp":1585842446113},{"file_id":"1cWwS-jbLYTDOpPp_hhKOLGFXfu06ccpG","timestamp":1585821375983},{"file_id":"1TPEE_AtGTLedawgVBwwXofEJEcJUCgo3","timestamp":1585137343783},{"file_id":"1SxFRb38aC_kmKzKVQfkwWzkK9n7YFxVv","timestamp":1585053829456},{"file_id":"15iw9IOwHNF_GhiHxkh_rWbJG8JnW14Wh","timestamp":1584375074441},{"file_id":"15oMbXnMa4LDEMhPHBr3ga0xhJomMLhDo","timestamp":1584105762670},{"file_id":"1__NtYFNA3DxNB7LrUY13Bt8_frye3iWl","timestamp":1583445015203},{"file_id":"11jsQfqKeDU1Zk3nPykjWKwYhFmvJ1zJ-","timestamp":1575289898486}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"WDrFAwpFIpE0"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"ABNu2p4stHeB"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"HVwncY_NvlYi"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"JrGNzgEyxzGQ"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"wYoajeT54sQM"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"TpT6gbwURzrV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"quzkzlRD45HF"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"eLwDxBnp4-bc","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"leK5kmgD5Ism"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vOeLpQfT0QF1","cellView":"form"},"source":["Notebook_version = ['1.11.2']\n","\n","\n","#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","\n","#!pip install keras==2.2.5\n","!pip install data\n","!pip install fpdf\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from 
sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","from datetime import datetime\n","\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n","\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n"," patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(patches_img.shape[0]):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n"," patch_num += 1\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n"," else:\n"," 
io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', convert2Mask(normalizeMinMax(patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# This is code outlines the 
architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 
'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," 
threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","print('Notebook version: '+Notebook_version[0])\n","\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7hTKImff6Est"},"source":["# **3. Select your parameters and paths**\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"S74FbqV6PNNv"},"source":["##**3.1. Parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"3np5EpJF8_q2"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. 
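For illustration, the short sketch below (using assumed example numbers, not your data) shows how the batch size interacts with the number of valid patches and the validation split to give the number of training and validation steps per epoch, following the same calculation used later in this notebook.

```python
from math import ceil

# Assumed example values, for illustration only
number_of_training_dataset = 900   # number of valid training patches
batch_size = 4
percentage_validation = 10

# One epoch should iterate once over the training split
number_of_steps = ceil((100 - percentage_validation) / 100
                       * number_of_training_dataset / batch_size)
validation_steps = max(1, ceil(percentage_validation / 100
                               * number_of_training_dataset / batch_size))
print(number_of_steps, validation_steps)   # -> 203 23
```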
Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n","**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. It should be between 0 and 1.**Default value: 0.02** (2%)\n","\n"]},{"cell_type":"code","metadata":{"id":"7deNuPZd5d-B","cellView":"form"},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","min_fraction = 0.02#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n"," min_fraction = 0.02\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","print('Total number of valid patches: '+str(number_of_training_dataset))\n","\n","if Use_Default_Advanced_Parameters or number_of_steps == 0:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","print('Number of steps: '+str(number_of_steps))\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","# Build the default dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = 0.,\n"," height_shift_range = 0.,\n"," rotation_range = 0., #90\n"," zoom_range = 0.,\n"," shear_range = 0.,\n"," horizontal_flip = False,\n"," vertical_flip = False,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V9UCjlLJ5Rfc"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 
10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"i-PahNX94-pl","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," 
img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7vFEIHbNAuOs"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"RfR9UyKAAulw","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"94FX4wzE8w1W"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"tlTDGcmDDHDe"},"source":["## **4.1. 
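The learning-rate recovery above reads `training_evaluation.csv` from the model's Quality Control folder. A minimal sketch of the same lookup with pandas, assuming a CSV with `loss`, `val_loss` and `learning rate` columns (the path below is a placeholder):

```python
import pandas as pd

# Placeholder path to a model folder created by this notebook.
csv_path = '/content/gdrive/My Drive/my_model/Quality Control/training_evaluation.csv'
history = pd.read_csv(csv_path)

if 'learning rate' in history.columns:
    last_lr = history['learning rate'].iloc[-1]                           # Weights_choice == "last"
    best_lr = history.loc[history['val_loss'].idxmin(), 'learning rate']  # Weights_choice == "best"
    print('last learning rate:', last_lr, '| learning rate at lowest val_loss:', best_lr)
else:
    print('No learning rate column found; the default initial_learning_rate would be used instead.')
```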
Prepare model for training**\n","---"]},{"cell_type":"code","metadata":{"id":"ezFy_mpz_op4","cellView":"form"},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","else:\n"," h5_file_path = None\n","\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"urpQ9UM-6NBE"},"source":["## **4.2. Start Training**\n","---\n","\n","####**Be patient**. Please be patient, this may take a while. But the verbose allow you to estimate how fast it's training and how long it'll take. 
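`getClassWeights(Training_target)` above balances the loss between the (typically sparse) foreground and the abundant background pixels. A sketch of the underlying balancing formula, `n_samples / (n_classes * class_count)`, applied to a single illustrative binary mask:

```python
import numpy as np

# Placeholder binary mask with a small foreground object.
mask = np.zeros((256, 256), dtype=np.float32)
mask[100:120, 100:140] = 1.0

# class_count = [background pixels, foreground pixels]
class_count = np.array([mask.size - mask.sum(), mask.sum()])
class_weights = mask.size / (2 * class_count)   # n_samples / (n_classes * class_count)
print(class_weights)  # the sparse foreground class receives the larger weight
```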
While it's training, please make sure that the computer is not powering down due to inactivity, otherwise this will interupt the runtime."]},{"cell_type":"code","metadata":{"id":"sMyCENd29TKz","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'U-Net 2D'\n","\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = 
gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(180, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if rotation_range != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if horizontal_flip == True or vertical_flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if zoom_range != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if horizontal_shift != 0 or vertical_shift != 0:\n"," aug_text = aug_text+'\\n- shifting'\n"," if shear_range != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40%>\n","  <tr><th>Parameter</th><th>Value</th></tr>\n","  <tr><td>number_of_epochs</td><td>{0}</td></tr>\n","  <tr><td>patch_size</td><td>{1}</td></tr>\n","  <tr><td>batch_size</td><td>{2}</td></tr>\n","  <tr><td>number_of_steps</td><td>{3}</td></tr>\n","  <tr><td>percentage_validation</td><td>{4}</td></tr>\n","  <tr><td>initial_learning_rate</td><td>{5}</td></tr>\n","  <tr><td>pooling_steps</td><td>{6}</td></tr>\n","  <tr><td>min_fraction</td><td>{7}</td></tr>\n","</table>
\n","\"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n","pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","# if Use_Data_augmentation:\n","# ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n","# pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n","print('------------------------------')\n","print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LWaFk0JNda-N"},"source":["## **4.3. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"mEMcFNHZdmTz"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"X11zGW0Ldu-z","cellView":"form"},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"pkJyRzWJCrKG"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qul6BpaX1GqS","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"h33P0C2geqZu"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. 
**These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"Tpqjvwv2zug-","cellView":"form"},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. 
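For reference, a hedged sketch of what the threshold sweep performed by `getIoUvsThreshold` amounts to: the prediction is binarised at every grey value from 0 to 255 and the IoU against the ground-truth mask is recorded for each threshold (the arrays below are placeholders, not QC data):

```python
import numpy as np

def iou_vs_threshold(prediction, ground_truth):
    """prediction: 8-bit image (0-255); ground_truth: binary mask."""
    gt = ground_truth > 0
    thresholds, scores = list(range(256)), []
    for t in thresholds:
        pred_mask = prediction > t
        union = np.logical_or(pred_mask, gt).sum()
        intersection = np.logical_and(pred_mask, gt).sum()
        scores.append(intersection / union if union > 0 else 0.0)
    return thresholds, scores

# Placeholder arrays standing in for a saved prediction and its ground-truth mask.
pred = (np.random.rand(64, 64) * 255).astype(np.uint8)
gt = (np.random.rand(64, 64) > 0.8).astype(np.uint8)
thresholds, scores = iou_vs_threshold(pred, gt)
print('best threshold:', thresholds[int(np.argmax(scores))], '| best IoU:', round(max(scores), 3))
```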
Threshold')\n","plt.ylabel('Threshold value')\n","plt.xlabel('IoU')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Unet 2D'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', 
size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n","#pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1]\n"," IoU_OptThresh = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,IoU,IoU_OptThresh)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = row[0]\n"," IoU = row[1]\n"," IoU_OptThresh = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>\n","</table>
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","print('------------------------------')\n","print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gofmRsLP96O8"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Pv_v1Ru2OJkU"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. 
Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"FJAe55ZoOJGs","cellView":"form"},"source":["\n","\n","# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = 
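`predict_as_tiles` is the notebook's own helper; as an illustration of the general idea only (not the exact implementation), a prediction on images larger than the model input can be obtained by padding to a multiple of the tile size, predicting tile by tile and stitching the result. Input normalisation is omitted for brevity and all names and paths below are placeholders:

```python
import numpy as np

def predict_tiled(img, model, tile=(512, 512)):
    """Pad 'img' to a multiple of 'tile', predict each tile, stitch and crop back."""
    h, w = img.shape[:2]
    pad_h, pad_w = (-h) % tile[0], (-w) % tile[1]
    padded = np.pad(img.astype(np.float32), ((0, pad_h), (0, pad_w)), mode='reflect')
    out = np.zeros(padded.shape, dtype=np.float32)
    for y in range(0, padded.shape[0], tile[0]):
        for x in range(0, padded.shape[1], tile[1]):
            patch = padded[y:y + tile[0], x:x + tile[1]][None, ..., None]
            out[y:y + tile[0], x:x + tile[1]] = model.predict(patch)[0, ..., 0]
    return out[:h, :w]

# prediction = predict_tiled(io.imread('image.tif'), unet)   # names/paths are placeholders
# mask = np.uint8(prediction * 255 > 120) * 255              # thresholding as in section 6.2
```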
convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"su-Mo2POVpja"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"iC_B_9lxNUny","cellView":"form"},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wYmwCQKjYsJ7"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"sCXzzvnh2_rc"},"source":["#**Thank you for using U-Net!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
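Because Training_source and Training_target files must share the same names (see the data structure above), a quick sanity check before training can save a failed run. A minimal sketch, with placeholder folder paths:

```python
import os

Training_source = '/content/gdrive/My Drive/Training_source'   # placeholder path
Training_target = '/content/gdrive/My Drive/Training_target'   # placeholder path

source_files = set(os.listdir(Training_source))
target_files = set(os.listdir(Training_target))

missing_masks = sorted(source_files - target_files)
if missing_masks:
    print('These source images have no matching mask:', missing_masks)
else:
    print('All', len(source_files), 'source images have a matching mask.')
```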
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Play to install U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","\n","#!pip install keras==2.2.5\n","!pip install data\n","!pip install fpdf\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from 
sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","from datetime import datetime\n","\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n","\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n"," patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(patches_img.shape[0]):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n"," patch_num += 1\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n"," else:\n"," 
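`create_patches` above cuts non-overlapping patches with `view_as_windows` (step equal to the patch size) and rejects patches whose foreground content falls below `min_fraction`. A condensed sketch of that selection on one illustrative binary mask (values below are placeholders):

```python
import numpy as np
from skimage.util.shape import view_as_windows

patch_width = patch_height = 128
min_fraction = 0.02

# Placeholder binary mask standing in for one Training_target image.
mask = (np.random.rand(512, 512) > 0.95).astype(np.uint8)

patches = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))
patches = patches.reshape(-1, patch_width, patch_height)

kept = [p for p in patches if p.mean() >= min_fraction]   # foreground-fraction filter
print(len(kept), 'of', len(patches), 'patches contain enough foreground')
```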
io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', convert2Mask(normalizeMinMax(patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# This is code outlines the 
architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," \n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 
'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," 
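# Sweep every 8-bit threshold value (0-255), binarise the prediction at each one and compute the IoU against the ground-truth mask\n","  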
threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","print('Notebook version: '+Notebook_version[0])\n","\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 2D'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," 
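# The CUDA version and GPU model parsed above are quoted in the methods text of the PDF training report\n","  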
#print(gpu_name)\n"," loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if rotation_range != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if horizontal_flip == True or vertical_flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if zoom_range != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if horizontal_shift != 0 or vertical_shift != 0:\n"," aug_text = aug_text+'\\n- shifting'\n"," if shear_range != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th align=\"left\">Parameter</th><th align=\"left\">Value</th></tr>\n","    
<tr><td>number_of_epochs</td><td>{0}</td></tr>\n","    
<tr><td>patch_size</td><td>{1}</td></tr>\n","    
<tr><td>batch_size</td><td>{2}</td></tr>\n","    
<tr><td>number_of_steps</td><td>{3}</td></tr>\n","    
<tr><td>percentage_validation</td><td>{4}</td></tr>\n","    
<tr><td>initial_learning_rate</td><td>{5}</td></tr>\n","    
<tr><td>pooling_steps</td><td>{6}</td></tr>\n","    
<tr><td>min_fraction</td><td>{7}</td></tr>
\n"," \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Unet 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n"," #pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1]\n"," IoU_OptThresh = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,IoU,IoU_OptThresh)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = 
row[0]\n"," IoU = row[1]\n"," IoU_OptThresh = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th align=\"left\">{0}</th><th align=\"left\">{1}</th><th align=\"left\">{2}</th></tr>\n","    
<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. 
**Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n","**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. It should be between 0 and 1.**Default value: 0.02** (2%)\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","min_fraction = 0.02#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n"," min_fraction = 0.02\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","print('Total number of valid patches: '+str(number_of_training_dataset))\n","\n","if Use_Default_Advanced_Parameters or number_of_steps == 0:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","print('Number of steps: '+str(number_of_steps))\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","# Build the default dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = 0.,\n"," height_shift_range = 0.,\n"," rotation_range = 0., #90\n"," zoom_range = 0.,\n"," shear_range = 0.,\n"," horizontal_flip = False,\n"," vertical_flip = False,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 
10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," 
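# Load each augmented example saved to disk and show it in the figure grid next to the original patch\n","        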
img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")\n","\n"," "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","else:\n"," h5_file_path = None\n","\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. 
**These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. 
Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["\n","\n","# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"stS96mFZLMOU"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"qb5ZmFstLNbR"},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using 2D U-Net!**\n"]}]} \ No newline at end of file diff --git a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb index c221e527..3eb5d11c 100644 --- a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb @@ -1,2354 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "U-Net_3D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.3" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "V9zNGvape2-I" - }, - "source": [ - "# **U-Net (3D)**\n", - " ---\n", - "\n", - " The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n", - "\n", - "**This particular implementation allows supervised learning between any two types of 3D image data. If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n", - "\n", - "This notebook is laregly based on the following paper: \n", - "\n", - "[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n", - "\n", - "The following two Python libraries play an important role in the notebook: \n", - "\n", - "1. [**Elasticdeform**](https://github.com/gvtulder/elasticdeform)\n", - " by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n", - "\n", - "2. [**Tifffile**](https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n", - "\n", - "3. [**Imgaug**](https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. 
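As a minimal illustration of the first of these libraries, the sketch below (toy NumPy arrays standing in for real volumes, not this notebook's own code) shows how elasticdeform can warp a source volume and its annotation with one shared random displacement grid so that image and target stay aligned; the generator defined in Section 2 of this notebook wraps a similar call in its `deform_volume` method.

```python
import numpy as np
import elasticdeform

# Toy volumes (z, y, x) standing in for an image stack and its annotation.
source = np.random.rand(16, 64, 64)
target = (np.random.rand(16, 64, 64) > 0.5).astype(float)

# Deform both arrays with the same random displacement grid.
[source_deformed, target_deformed] = elasticdeform.deform_random_grid(
    [source, target], sigma=4, points=3, order=3)

# Re-binarise the deformed annotation, as the notebook's generator does.
target_deformed = target_deformed > 0.1
```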
\n", - "\n", - "The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n", - "\n", - "\n", - "**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cells: \n", - "\n", - "**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code which can be modified by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "Three tabs are located on the upper left side of the notebook:\n", - "\n", - "1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n", - "\n", - "2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n", - "\n", - "3. *Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.\n", - "\n", - "**Important:** All uploaded files are purged once the runtime ends.\n", - "\n", - "**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n", - "\n", - "To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n", - "You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vNMDQHm0Ah-Z" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - "\n", - "As the network operates in three dimensions, careful consideration should be given to correctly pre-processing the data. 
Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixel sizes are ideal for this architecture.\n", - "\n", - "Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target files do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n", - "\n", - "**Prepare two datasets** (*training* and *testing*) for quality control purposes. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisition and sample to ensure robustness of the trained model. \n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Directory structure**\n", - "\n", - "Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. In case more than one training volume is available, choose the second structure.\n", - "\n", - "**Structure 1:** Only one training volume\n", - "```\n", - "path/to/directory/with/one/training/volume\n", - "│--training_source.tif\n", - "│--training_target.tif\n", - "| \n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "\n", - "```\n", - "**Structure 2:** Various training volumes\n", - "```\n", - "path/to/directory/with/various/training/volumes\n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "└───training\n", - "| └───source\n", - "| | |--training_volume_one.tif\n", - "| | |--training_volume_two.tif\n", - "| | |--...\n", - "| | |--training_volume_n.tif\n", - "| |\n", - "| └───target\n", - "| |--training_volume_one.tif\n", - "| |--training_volume_two.tif\n", - "| |--...\n", - "| |--training_volume_n.tif\n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "```\n", - "**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Important note**\n", - "\n", - "* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n", - "\n", - "* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n", - "\n", - "* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "fFdz-rHnQxld" - }, - "source": [ - "#@markdown ##**Download example dataset**\n", - "\n", - "#@markdown This usually takes a few minutes. 
The images are saved in *example_dataset*.\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "def make_directory(dir):\n", - " if not os.path.exists(dir):\n", - " os.makedirs(dir)\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "\n", - "make_directory('example_dataset')\n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n", - "\n", - "print('Example dataset successfully downloaded!')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BCPhV-pe-syw" - }, - "source": [ - "\n", - "## **1.1. Check GPU access and Python version**\n", - "---\n", - "\n", - "By default, Colab sessions run Python 3 with GPU acceleration. You can manually set this by:\n", - "\n", - "1. Going to **Runtime -> Change runtime type**\n", - "\n", - "2. **Runtime type: Python 3** *(This notebook uses Python 3)*\n", - "\n", - "3. **Accelator: GPU** *(Graphics Processing Unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "r9eqe5TazD5o" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UBrnApIUBgxv" - }, - "source": [ - "## **1.2. Mount Google Drive**\n", - "---\n", - " To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n", - "\n", - "1. **Run** the **cell** below to mount your Google Drive and follow the link. \n", - "\n", - "2. **Sign in** to your Google account and press 'Allow'. \n", - "\n", - "3. Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. \n", - "\n", - "4. 
Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "source": [ - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive', force_remount=True)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "Vspvj5Q2ijd4" - }, - "source": [ - "#@markdown ##Unzip pre-trained model directory\n", - "\n", - "#@markdown 1. Upload a zipped model directory using the *Files* tab\n", - "#@markdown 2. Run this cell to unzip your model file\n", - "#@markdown 3. The model directory will appear in the *Files* tab \n", - "\n", - "from google.colab import files\n", - "\n", - "zipped_model_file = \"\" #@param {type:\"string\"}\n", - "\n", - "!unzip \"$zipped_model_file\"" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **2. Install 3D U-Net dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3u2mXn3XsWzd", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install dependencies and instantiate network\n", - "Notebook_version = ['1.11']\n", - "#Put the imported code and libraries here\n", - "!pip install fpdf\n", - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "try:\n", - " import elasticdeform\n", - "except:\n", - " !pip install elasticdeform\n", - " import elasticdeform\n", - "\n", - "try:\n", - " import tifffile\n", - "except:\n", - " !pip install tifffile\n", - " import tifffile\n", - "\n", - "try:\n", - " import imgaug.augmenters as iaa\n", - "except:\n", - " !pip install imgaug\n", - " import imgaug.augmenters as iaa\n", - "\n", - "try:\n", - " import bcolors\n", - "except:\n", - " !pip install bcolors\n", - " import bcolors\n", - "\n", - "import os\n", - "import csv\n", - "import random\n", - "import h5py\n", - "import imageio\n", - "import math\n", - "import shutil\n", - "\n", - "import pandas as pd\n", - "from glob import glob\n", - "from tqdm import tqdm\n", - "\n", - "from skimage import transform\n", - "from skimage import exposure\n", - "from skimage import color\n", - "from skimage import io\n", - "\n", - "from scipy.ndimage import zoom\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "\n", - "from keras import backend as K\n", - "\n", - "from keras.layers import Conv3D\n", - "from keras.layers import BatchNormalization\n", - "from keras.layers import ReLU\n", - "from keras.layers import MaxPooling3D\n", - "from keras.layers import Conv3DTranspose\n", - "from keras.layers import Input\n", - "from keras.layers import Concatenate\n", - "\n", - "from keras.models import Model\n", - "\n", - "from keras.utils import Sequence\n", - "\n", - "from keras.callbacks import ModelCheckpoint\n", - "from keras.callbacks import CSVLogger\n", - "from keras.callbacks import Callback\n", - "\n", - "from keras.metrics import RootMeanSquaredError\n", - "\n", - "from ipywidgets import interact\n", - "from ipywidgets import interactive\n", - "from ipywidgets import fixed\n", - "from ipywidgets import interact_manual \n", - "import ipywidgets as widgets\n", - "\n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import 
datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "import time\n", - "\n", - "from skimage import io\n", - "import matplotlib\n", - "\n", - "print(\"Dependencies installed and imported.\")\n", - "\n", - "# Define MultiPageTiffGenerator class\n", - "class MultiPageTiffGenerator(Sequence):\n", - "\n", - " def __init__(self,\n", - " source_path,\n", - " target_path,\n", - " batch_size=1,\n", - " shape=(128,128,32,1),\n", - " augment=False,\n", - " augmentations=[],\n", - " deform_augment=False,\n", - " deform_augmentation_params=(5,3,4),\n", - " val_split=0.2,\n", - " is_val=False,\n", - " random_crop=True,\n", - " downscale=1,\n", - " binary_target=False):\n", - "\n", - " # If directory with various multi-page tiffiles is provided read as list\n", - " if os.path.isfile(source_path):\n", - " self.dir_flag = False\n", - " self.source = tifffile.imread(source_path)\n", - " if binary_target:\n", - " self.target = tifffile.imread(target_path).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(target_path)\n", - "\n", - " elif os.path.isdir(source_path):\n", - " self.dir_flag = True\n", - " self.source_dir_list = glob(os.path.join(source_path, '*'))\n", - " self.target_dir_list = glob(os.path.join(target_path, '*'))\n", - "\n", - " self.source_dir_list.sort()\n", - " self.target_dir_list.sort()\n", - "\n", - " self.shape = shape\n", - " self.batch_size = batch_size\n", - " self.augment = augment\n", - " self.val_split = val_split\n", - " self.is_val = is_val\n", - " self.random_crop = random_crop\n", - " self.downscale = downscale\n", - " self.binary_target = binary_target\n", - " self.deform_augment = deform_augment\n", - " self.on_epoch_end()\n", - " \n", - " if self.augment:\n", - " # pass list of augmentation functions \n", - " self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n", - " if self.deform_augment:\n", - " self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n", - "\n", - " def __len__(self):\n", - " # If various multi-page tiff files provided sum all images within each\n", - " if self.augment:\n", - " augment_factor = 4\n", - " else:\n", - " augment_factor = 1\n", - " \n", - " if self.dir_flag:\n", - " num_of_imgs = 0\n", - " for tiff_path in self.source_dir_list:\n", - " num_of_imgs += tifffile.imread(tiff_path).shape[0]\n", - " xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n", - "\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(self.val_split * num_of_imgs / self.batch_size)\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - "\n", - " else:\n", - " return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n", - " else:\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]\n", - " return 
math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n", - "\n", - " def __getitem__(self, idx):\n", - " source_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - " target_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - "\n", - " for batch in range(self.batch_size):\n", - " # Modulo operator ensures IndexError is avoided\n", - " stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n", - "\n", - " if self.dir_flag:\n", - " self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n", - " if self.binary_target:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n", - "\n", - " src_list = []\n", - " tgt_list = []\n", - " for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n", - " src = self.source[i]\n", - " src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " src = self._min_max_scaling(src)\n", - " src_list.append(src)\n", - "\n", - " tgt = self.target[i]\n", - " tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " if not self.binary_target:\n", - " tgt = self._min_max_scaling(tgt)\n", - " tgt_list.append(tgt)\n", - "\n", - " if self.random_crop:\n", - " if src.shape[0] == self.shape[0]:\n", - " x_rand = 0\n", - " if src.shape[1] == self.shape[1]:\n", - " y_rand = 0\n", - " if src.shape[0] > self.shape[0]:\n", - " x_rand = np.random.randint(src.shape[0] - self.shape[0])\n", - " if src.shape[1] > self.shape[1]:\n", - " y_rand = np.random.randint(src.shape[1] - self.shape[1])\n", - " if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - " \n", - " for i in range(self.shape[2]):\n", - " if self.random_crop:\n", - " src = src_list[i]\n", - " tgt = tgt_list[i]\n", - " src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " else:\n", - " src_crop = src_list[i]\n", - " tgt_crop = tgt_list[i]\n", - "\n", - " source_batch[batch,:,:,i,0] = src_crop\n", - " target_batch[batch,:,:,i,0] = tgt_crop\n", - "\n", - " if self.augment:\n", - " # On-the-fly data augmentation\n", - " source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n", - "\n", - " # Data augmentation by reversing stack\n", - " if np.random.random() > 0.5:\n", - " source_batch, 
target_batch = source_batch[::-1], target_batch[::-1]\n", - " \n", - " # Data augmentation by elastic deformation\n", - " if np.random.random() > 0.5 and self.deform_augment:\n", - " source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n", - " \n", - " if not self.binary_target:\n", - " target_batch = self._min_max_scaling(target_batch)\n", - " \n", - " return self._min_max_scaling(source_batch), target_batch\n", - " \n", - " else:\n", - " return source_batch, target_batch\n", - "\n", - " def on_epoch_end(self):\n", - " # Validation split performed here\n", - " self.batch_list = []\n", - " # Create batch_list of all combinations of tifffile and stack position\n", - " if self.dir_flag:\n", - " for i in range(len(self.source_dir_list)):\n", - " num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " num_of_pages = self.source.shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - "\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - " \n", - " if self.is_val and (len(self.batch_list) <= 0):\n", - " raise ValueError('validation_split too small! Increase val_split or decrease z-depth')\n", - " random.shuffle(self.batch_list)\n", - " \n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - " \n", - " def class_weights(self):\n", - " ones = 0\n", - " pixels = 0\n", - "\n", - " if self.dir_flag:\n", - " for i in range(len(self.target_dir_list)):\n", - " tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n", - " ones += np.sum(tgt)\n", - " pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n", - " else:\n", - " ones = np.sum(self.target)\n", - " pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n", - " p_ones = ones/pixels\n", - " p_zeros = 1-p_ones\n", - "\n", - " # Return swapped probability to increase weight of unlikely class\n", - " return p_ones, p_zeros\n", - "\n", - " def deform_volume(self, src_vol, tgt_vol):\n", - " [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n", - " axis=(1, 2, 3),\n", - " sigma=self.deform_sigma,\n", - " points=self.deform_points,\n", - " order=self.deform_order)\n", - " if self.binary_target:\n", - " tgt_dfrm = tgt_dfrm > 0.1\n", - " \n", - " return self._min_max_scaling(src_dfrm), tgt_dfrm \n", - "\n", - " def augment_volume(self, src_vol, tgt_vol):\n", - " src_vol_aug = np.empty(src_vol.shape)\n", - " tgt_vol_aug = np.empty(tgt_vol.shape)\n", - "\n", - " for i in range(src_vol.shape[3]):\n", - " src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n", - " segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n", - " return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n", - "\n", - " def sample_augmentation(self, idx):\n", - " src, tgt = self.__getitem__(idx)\n", - "\n", - " 
src_aug, tgt_aug = self.augment_volume(src, tgt)\n", - " \n", - " if self.deform_augment:\n", - " src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n", - "\n", - " return src_aug, tgt_aug \n", - "\n", - "# Define custom loss and dice coefficient\n", - "def dice_coefficient(y_true, y_pred):\n", - " eps = 1e-6\n", - " y_true_f = K.flatten(y_true)\n", - " y_pred_f = K.flatten(y_pred)\n", - " intersection = K.sum(y_true_f*y_pred_f)\n", - "\n", - " return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n", - "\n", - "def weighted_binary_crossentropy(zero_weight, one_weight):\n", - " def _weighted_binary_crossentropy(y_true, y_pred):\n", - " binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n", - "\n", - " weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n", - " weighted_binary_crossentropy = weight_vector*binary_crossentropy\n", - "\n", - " return K.mean(weighted_binary_crossentropy)\n", - "\n", - " return _weighted_binary_crossentropy\n", - "\n", - "# Custom callback showing sample prediction\n", - "class SampleImageCallback(Callback):\n", - "\n", - " def __init__(self, model, sample_data, model_path, save=False):\n", - " self.model = model\n", - " self.sample_data = sample_data\n", - " self.model_path = model_path\n", - " self.save = save\n", - "\n", - " def on_epoch_end(self, epoch, logs={}):\n", - " sample_predict = self.model.predict_on_batch(self.sample_data)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n", - " plt.title('Sample source')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', cmap='magma')\n", - " plt.title('Predicted target')\n", - " plt.axis('off');\n", - "\n", - " plt.show()\n", - "\n", - " if self.save:\n", - " plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n", - "\n", - "\n", - "# Define Unet3D class\n", - "class Unet3D:\n", - "\n", - " def __init__(self,\n", - " shape=(256,256,16,1)):\n", - " if isinstance(shape, str):\n", - " shape = eval(shape)\n", - "\n", - " self.shape = shape\n", - " \n", - " input_tensor = Input(self.shape, name='input')\n", - "\n", - " self.model = self.unet_3D(input_tensor)\n", - "\n", - " def down_block_3D(self, input_tensor, filters):\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def up_block_3D(self, input_tensor, concat_layer, filters):\n", - " x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n", - "\n", - " x = Concatenate()([x, concat_layer])\n", - "\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def unet_3D(self, input_tensor, filters=32):\n", - " d1 = self.down_block_3D(input_tensor, filters=filters)\n", - " p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n", - " d2 = self.down_block_3D(p1, filters=filters*2)\n", - " p2 = MaxPooling3D(pool_size=(2,2,2), 
strides=(2,2,2), data_format='channels_last')(d2)\n", - " d3 = self.down_block_3D(p2, filters=filters*4)\n", - " p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n", - "\n", - " d4 = self.down_block_3D(p3, filters=filters*8)\n", - "\n", - " u1 = self.up_block_3D(d4, d3, filters=filters*4)\n", - " u2 = self.up_block_3D(u1, d2, filters=filters*2)\n", - " u3 = self.up_block_3D(u2, d1, filters=filters)\n", - "\n", - " output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n", - "\n", - " return Model(inputs=[input_tensor], outputs=[output_tensor])\n", - "\n", - " def summary(self):\n", - " return self.model.summary()\n", - "\n", - " # Pass generators instead\n", - " def train(self, \n", - " epochs, \n", - " batch_size, \n", - " train_generator,\n", - " val_generator, \n", - " model_path, \n", - " model_name,\n", - " optimizer='adam',\n", - " loss='weighted_binary_crossentropy',\n", - " metrics='dice',\n", - " ckpt_period=1, \n", - " save_best_ckpt_only=False, \n", - " ckpt_path=None):\n", - "\n", - " class_weight_zero, class_weight_one = train_generator.class_weights()\n", - " \n", - " if loss == 'weighted_binary_crossentropy':\n", - " loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n", - " \n", - " if metrics == 'dice':\n", - " metrics = dice_coefficient\n", - "\n", - " self.model.compile(optimizer=optimizer,\n", - " loss=loss,\n", - " metrics=[metrics])\n", - "\n", - " if ckpt_path is not None:\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " full_model_path = os.path.join(model_path, model_name)\n", - "\n", - " if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - " \n", - " log_dir = full_model_path + '/Quality Control'\n", - "\n", - " if not os.path.exists(log_dir):\n", - " os.makedirs(log_dir)\n", - " \n", - " ckpt_dir = full_model_path + '/ckpt'\n", - "\n", - " if not os.path.exists(ckpt_dir):\n", - " os.makedirs(ckpt_dir)\n", - "\n", - " csv_out_name = log_dir + '/training_evaluation.csv'\n", - " if ckpt_path is None:\n", - " csv_logger = CSVLogger(csv_out_name)\n", - " else:\n", - " csv_logger = CSVLogger(csv_out_name, append=True)\n", - "\n", - " if save_best_ckpt_only:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n", - " else:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n", - " \n", - " model_ckpt = ModelCheckpoint(ckpt_name,\n", - " verbose=1,\n", - " period=ckpt_period,\n", - " save_best_only=save_best_ckpt_only,\n", - " save_weights_only=True)\n", - "\n", - " sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n", - " sample_img = SampleImageCallback(self.model, \n", - " sample_batch, \n", - " model_path)\n", - "\n", - " self.model.fit_generator(generator=train_generator,\n", - " validation_data=val_generator,\n", - " validation_steps=math.floor(len(val_generator)/batch_size),\n", - " epochs=epochs,\n", - " callbacks=[csv_logger,\n", - " model_ckpt,\n", - " sample_img])\n", - "\n", - " last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n", - " self.model.save_weights(last_ckpt_name)\n", - "\n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - "\n", - " def predict(self, \n", - " input, \n", - " ckpt_path, \n", - " z_range=None, \n", - " downscaling=None, \n", - " true_patch_size=None):\n", - "\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " 
if isinstance(downscaling, str):\n", - " downscaling = eval(downscaling)\n", - "\n", - " if math.isnan(downscaling):\n", - " downscaling = None\n", - "\n", - " if isinstance(true_patch_size, str):\n", - " true_patch_size = eval(true_patch_size)\n", - " \n", - " if not isinstance(true_patch_size, tuple): \n", - " if math.isnan(true_patch_size):\n", - " true_patch_size = None\n", - "\n", - " if isinstance(input, str):\n", - " src_volume = tifffile.imread(input)\n", - " elif isinstance(input, np.ndarray):\n", - " src_volume = input\n", - " else:\n", - " raise TypeError('Input is not path or numpy array!')\n", - " \n", - " in_size = src_volume.shape\n", - "\n", - " if downscaling or true_patch_size is not None:\n", - " x_scaling = 0\n", - " y_scaling = 0\n", - "\n", - " if true_patch_size is not None:\n", - " x_scaling += true_patch_size[0]/self.shape[0]\n", - " y_scaling += true_patch_size[1]/self.shape[1]\n", - " if downscaling is not None:\n", - " x_scaling += downscaling\n", - " y_scaling += downscaling\n", - "\n", - " src_list = []\n", - " for i in range(src_volume.shape[0]):\n", - " src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n", - " src_volume = np.array(src_list) \n", - "\n", - " if z_range is not None:\n", - " src_volume = src_volume[z_range[0]:z_range[1]]\n", - "\n", - " src_volume = self._min_max_scaling(src_volume) \n", - "\n", - " src_array = np.zeros((1,\n", - " math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n", - " math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n", - " math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n", - " self.shape[3]))\n", - "\n", - " for i in range(src_volume.shape[0]):\n", - " src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n", - "\n", - " pred_array = np.empty(src_array.shape)\n", - "\n", - " for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n", - " for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n", - " for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):\n", - " pred_temp = self.model.predict(src_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n", - " pred_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n", - " \n", - " pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n", - "\n", - " if downscaling is not None:\n", - " pred_list = []\n", - " for i in range(pred_volume.shape[0]):\n", - " pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n", - " pred_volume = np.array(pred_list)\n", - "\n", - " return pred_volume\n", - "\n", - "\n", - "# -------------- Other definitions -----------\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "prediction_prefix = 'Predicted_'\n", - "\n", - "\n", - "print('-------------------')\n", - "print('U-Net 3D and dependencies installed.')\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - " \n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "# Exporting requirements.txt for local run\n", - "!pip freeze > requirements.txt\n", - " " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Fw0kkTU6CsU4" - }, - "source": [ - "# **3. Select your model parameters**\n", - "\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Vpc-4zCSpvHK" - }, - "source": [ - "## **3.1. Choosing parameters**\n", - "\n", - "---\n", - "\n", - "### **Paths to training data and model**\n", - "\n", - "* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n", - "\n", - "* **`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n", - "\n", - "* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n", - "\n", - "\n", - "**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n", - "\n", - "### **Training parameters**\n", - "\n", - "* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n", - "\n", - "* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. *Default: 1*\n", - "\n", - "* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n", - "\n", - "* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* \n", - "\n", - "* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n", - "\n", - "* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* \n", - "\n", - "* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks *Default: True* \n", - "\n", - "* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). 
*Default: weighted_binary_crossentropy* \n", - "\n", - "* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n", - "\n", - "* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n", - "\n", - "**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and if the error persists at `batch_size = 1`, reduce the `patch_size`. \n", - "\n", - "**Note:** The number of steps per epoch is calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size` where `augment_factor` is four if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. For example, 100 training slices with a 20% validation split, augmentation enabled and `batch_size = 1` give `floor(4 * 0.8 * 100 / 1) = 320` steps per epoch. If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch is calculated as `floor(augment_factor * volume / (crop_volume * batch_size))` where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is defined as the volume in pixels based on the specified `patch_size`." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training data:\n", - "training_source = \"\" #@param {type:\"string\"}\n", - "training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Model name and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "full_model_path = os.path.join(model_path, model_name)\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Training parameters\n", - "number_of_epochs = 150#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Default advanced parameters\n", - "use_default_advanced_parameters = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown If not, please change:\n", - "\n", - "batch_size = 1#@param {type:\"number\"}\n", - "patch_size = (256,256,16) #@param {type:\"number\"} # in pixels\n", - "training_shape = patch_size + (1,)\n", - "image_pre_processing = 'randomly crop to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n", - "\n", - "validation_split_in_percent = 20 #@param{type:\"number\"}\n", - "downscaling_in_xy = 2#@param {type:\"number\"} # in pixels\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n", - "\n", - "metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n", - "\n", - "optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n", - "\n", - "\n", - "if image_pre_processing == \"randomly crop to patch_size\":\n", - " random_crop = True\n", - "else:\n", - " random_crop = False\n", - "\n", - "if use_default_advanced_parameters: \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " training_shape = (256,256,8,1)\n", - " validation_split_in_percent = 20\n", - " downscaling_in_xy = 1\n", - " random_crop = True\n", - " binary_target = True\n", - " loss_function = 'weighted_binary_crossentropy'\n", - " metrics = 
'dice'\n", - " optimizer = 'adam'\n", - "\n", - "#@markdown ###Checkpointing parameters\n", - "checkpointing_period = 1 #@param {type:\"number\"}\n", - "\n", - "#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n", - "save_best_only = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Resume training\n", - "#@markdown Choose if training was interrupted:\n", - "resume_training = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Transfer learning\n", - "#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n", - "checkpoint_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if resume_training and checkpoint_path != \"\":\n", - " print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n", - " resume_training = False\n", - " \n", - "\n", - "# Retrieve last checkpoint\n", - "if resume_training:\n", - " try:\n", - " ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort()\n", - " last_ckpt_path = ckpt_dir_list[-1]\n", - " print('Training will resume from checkpoint:', os.path.basename(last_ckpt_path))\n", - " except IndexError:\n", - " last_ckpt_path=None\n", - " print('CheckpointError: No previous checkpoints were found, training from scratch.')\n", - "elif not resume_training and checkpoint_path != \"\":\n", - " last_ckpt_path = checkpoint_path\n", - " assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n", - "else:\n", - " last_ckpt_path=None\n", - "\n", - "# Instantiate Unet3D \n", - "model = Unet3D(shape=training_shape)\n", - "\n", - "#here we check that no model with the same name already exist\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " print(bcolors.WARNING+'The model folder already exists and will be overwritten.')\n", - " # print('!! 
WARNING: Folder already exists and will be overwritten !!') \n", - " # shutil.rmtree(full_model_path)\n", - "\n", - "# if not os.path.exists(full_model_path):\n", - "# os.makedirs(full_model_path)\n", - "\n", - "# Show sample image\n", - "if os.path.isdir(training_source):\n", - " training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n", - " training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n", - "else:\n", - " training_source_sample = training_source\n", - " training_target_sample = training_target\n", - "\n", - "src_sample = tifffile.imread(training_source_sample)\n", - "src_sample = model._min_max_scaling(src_sample)\n", - "if binary_target:\n", - " tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n", - "else:\n", - " tgt_sample = tifffile.imread(training_target_sample)\n", - "\n", - "src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n", - "tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n", - "\n", - "if random_crop:\n", - " true_patch_size = None\n", - "\n", - " if src_down.shape[0] == training_shape[0]:\n", - " x_rand = 0\n", - " if src_down.shape[1] == training_shape[1]:\n", - " y_rand = 0\n", - " if src_down.shape[0] > training_shape[0]:\n", - " x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n", - " if src_down.shape[1] > training_shape[1]:\n", - " y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n", - " if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - "else:\n", - " true_patch_size = src_down.shape\n", - "\n", - "def scroll_in_z(z):\n", - " src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n", - " tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n", - " if random_crop:\n", - " src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " else:\n", - " \n", - " src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - " tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_slice, cmap='gray')\n", - " plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(tgt_slice, cmap='magma')\n", - " plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - " plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n", - " #plt.close()\n", - "\n", - "print('This is what the training images will look like with the chosen settings')\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n", - "\n", - "#Create a copy of an example slice and close the display.\n", - "scroll_in_z(z=int(src_sample.shape[0]/2))\n", - "plt.close()\n", - "\n", - "# Save model parameters\n", - "params = {'training_source': training_source,\n", - " 'training_target': training_target,\n", - " 'model_name': model_name,\n", - " 'model_path': model_path,\n", - " 
'number_of_epochs': number_of_epochs,\n", - " 'batch_size': batch_size,\n", - " 'training_shape': training_shape,\n", - " 'downscaling': downscaling_in_xy,\n", - " 'true_patch_size': true_patch_size,\n", - " 'val_split': validation_split_in_percent/100,\n", - " 'random_crop': random_crop}\n", - "\n", - "params_df = pd.DataFrame.from_dict(params, orient='index')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rni7MsNJwzaw" - }, - "source": [ - "## **3.2. Data augmentation**\n", - " \n", - "---\n", - " Augmenting the training data increases the robustness of the model by simulating possible variations within the training data, which helps prevent overfitting on small datasets. We therefore strongly recommend augmenting the data and making sure that the applied augmentations are reasonable.\n", - "\n", - "* **Gaussian blur** blurs images using Gaussian kernels with a sigma of `gaussian_sigma`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n", - "\n", - "* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `alpha` is sampled uniformly from the interval `[contrast_min, contrast_max]`. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n", - "\n", - "* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n", - "\n", - "* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imgaug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html). \n", - "![custom_augmenters](https://drive.google.com/uc?export=view&id=1LWTTxZEH-gG8T4Kb0BQ37lHWjn2Oq5K4)\n", - "In the example above, the augmentation pipeline is equivalent to: \n", - "```\n", - "seq = iaa.Sequential([\n", - " iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0))),\n", - " iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0))),\n", - " iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)))\n", - "], random_order=True)\n", - "```\n", - " Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n", - "\n", - "* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n", - "\n", - "* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. 
*Default: True*" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "UDy9ut0HYKLv" - }, - "source": [ - "#@markdown ##**Augmentation options**\n", - "\n", - "#@markdown ###Data augmentation\n", - "\n", - "apply_data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "# List of augmentations\n", - "augmentations = []\n", - "\n", - "#@markdown ###Gaussian blur\n", - "add_gaussian_blur = True #@param {type:\"boolean\"}\n", - "gaussian_sigma = 0.7#@param {type:\"number\"}\n", - "gaussian_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_gaussian_blur:\n", - " augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n", - "\n", - "#@markdown ###Linear contrast\n", - "add_linear_contrast = True #@param {type:\"boolean\"}\n", - "contrast_min = 0.4 #@param {type:\"number\"}\n", - "contrast_max = 1.6#@param {type:\"number\"}\n", - "contrast_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_linear_contrast:\n", - " augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n", - "\n", - "#@markdown ###Additive Gaussian noise\n", - "add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n", - "scale_min = 0 #@param {type:\"number\"}\n", - "scale_max = 0.05 #@param {type:\"number\"}\n", - "noise_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_additive_gaussian_noise:\n", - " augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n", - "\n", - "#@markdown ###Add custom augmenters\n", - "\n", - "augmenters = \"GammaContrast; AverageBlur; LinearContrast\" #@param {type:\"string\"}\n", - "\n", - "augmenter_params = \"(0.5, 2.0); (0.5, 2.0); (0.4, 1.6)\" #@param {type:\"string\"}\n", - "\n", - "augmenter_frequency = \"0.3; 0.4; 0.5\" #@param {type:\"string\"}\n", - "\n", - "aug_lst = augmenters.split(';')\n", - "aug_params_lst = augmenter_params.split(';')\n", - "aug_freq_lst = augmenter_frequency.split(';')\n", - "\n", - "assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n", - "\n", - "for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n", - " aug, param, freq = aug.strip(), param.strip(), freq.strip() \n", - " aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n", - " augmentations.append(aug_func)\n", - "\n", - "#@markdown ###Elastic deformations\n", - "add_elastic_deform = True #@param {type:\"boolean\"}\n", - "sigma = 2#@param {type:\"number\"}\n", - "points = 2#@param {type:\"number\"}\n", - "order = 2#@param {type:\"number\"}\n", - "\n", - "if add_elastic_deform:\n", - " deform_params = (sigma, points, order)\n", - "else:\n", - " deform_params = None\n", - "\n", - "train_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " shape=training_shape,\n", - " augment=apply_data_augmentation,\n", - " augmentations=augmentations,\n", - " deform_augment=add_elastic_deform,\n", - " deform_augmentation_params=deform_params,\n", - " val_split=validation_split_in_percent/100,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "val_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " 
shape=training_shape,\n", - " val_split=validation_split_in_percent/100,\n", - " is_val=True,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "\n", - "if apply_data_augmentation:\n", - " print('Data augmentation enabled.')\n", - " sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n", - "\n", - " def scroll_in_z(z):\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n", - " plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n", - " plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " print('This is what the augmented training images will look like with the chosen settings')\n", - " interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n", - "\n", - "else:\n", - " print('Data augmentation disabled.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rQndJj70FzfL" - }, - "source": [ - "# **4. Train the network**\n", - "---\n", - "\n", - "**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0NjxtSfqsgxx" - }, - "source": [ - "## **4.1. Show model and start training**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "opWPgUl7erct", - "cellView": "form" - }, - "source": [ - "#@markdown ## Show model summary\n", - "model.summary()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "EZnoS3rb8BSR", - "scrolled": false, - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " shutil.rmtree(full_model_path)\n", - " print(bcolors.WARNING+'!! 
WARNING: Folder already exists and has been overwritten !!') \n", - "\n", - "if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - "\n", - "# Save file\n", - "params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n", - "\n", - "start = time.time()\n", - "# Start Training\n", - "model.train(epochs=number_of_epochs,\n", - " batch_size=batch_size,\n", - " train_generator=train_generator,\n", - " val_generator=val_generator,\n", - " model_path=model_path,\n", - " model_name=model_name,\n", - " loss=loss_function,\n", - " metrics=metrics,\n", - " optimizer=optimizer,\n", - " ckpt_period=checkpointing_period,\n", - " save_best_ckpt_only=save_best_only,\n", - " ckpt_path=last_ckpt_path)\n", - "\n", - "print('Training successfully completed!')\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - "pdf = MyFPDF()\n", - "pdf.add_page()\n", - "pdf.set_right_margin(-1)\n", - "pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - "Network = 'U-Net 3D'\n", - "day = datetime.now()\n", - "datetime_str = str(day)[0:10]\n", - "\n", - "Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - "pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - "# add another cell \n", - "training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - "pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - "pdf.ln(1)\n", - "\n", - "Header_2 = 'Information for your materials and methods:'\n", - "pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - "all_packages = ''\n", - "for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "#print(all_packages)\n", - "\n", - "#Main Packages\n", - "main_packages = ''\n", - "version_numbers = []\n", - "for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - "cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - "cuda_version = cuda_version.stdout.decode('utf-8')\n", - "cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - "gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - "gpu_name = gpu_name.stdout.decode('utf-8')\n", - "gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - "#print(cuda_version[cuda_version.find(', V')+3:-1])\n", - "#print(gpu_name)\n", - "\n", - "shape = io.imread(training_source).shape\n", - "dataset_size = len(train_generator)\n", - "\n", - "text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - "if resume_training:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - "pdf.set_font('')\n", - "pdf.set_font_size(10.)\n", - "pdf.multi_cell(190, 5, txt = text, align='L')\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 10, style = 'B')\n", - "pdf.ln(1)\n", - "pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - "pdf.set_font('')\n", - "if apply_data_augmentation:\n", - " aug_text = 'The dataset was augmented by'\n", - " if add_gaussian_blur == True:\n", - " aug_text = aug_text+'\\n- gaussian blur'\n", - " if add_linear_contrast == True:\n", - " aug_text = aug_text+'\\n- linear contrast'\n", - " if add_additive_gaussian_noise == True:\n", - " aug_text = aug_text+'\\n- additive gaussian noise'\n", - " if augmenters != '':\n", - " aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n", - " if add_elastic_deform == True:\n", - " aug_text = aug_text+'\\n- elastic deformation'\n", - "else:\n", - " aug_text = 'No augmentation was used for training.'\n", - "pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - "pdf.set_font('Arial', size = 11, style = 'B')\n", - "pdf.ln(1)\n", - "pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - "pdf.set_font('')\n", - "pdf.set_font_size(10.)\n", - "if use_default_advanced_parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - "pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - "pdf.ln(1)\n", - "html = \"\"\" \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table width=60%>\n",
 - "  <tr><th align=\"left\">Parameter</th><th align=\"left\">Value</th></tr>\n",
 - "  <tr><td>number_of_epochs</td><td>{0}</td></tr>\n",
 - "  <tr><td>batch_size</td><td>{1}</td></tr>\n",
 - "  <tr><td>patch_size</td><td>{2}</td></tr>\n",
 - "  <tr><td>image_pre_processing</td><td>{3}</td></tr>\n",
 - "  <tr><td>validation_split_in_percent</td><td>{4}</td></tr>\n",
 - "  <tr><td>downscaling_in_xy</td><td>{5}</td></tr>\n",
 - "  <tr><td>binary_target</td><td>{6}</td></tr>\n",
 - "  <tr><td>loss_function</td><td>{7}</td></tr>\n",
 - "  <tr><td>metrics</td><td>{8}</td></tr>\n",
 - "  <tr><td>optimizer</td><td>{9}</td></tr>\n",
 - "  <tr><td>checkpointing_period</td><td>{10}</td></tr>\n",
 - "  <tr><td>save_best_only</td><td>{11}</td></tr>\n",
 - "  <tr><td>resume_training</td><td>{12}</td></tr>\n",
 - "</table>
\n", - "\"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, checkpointing_period, str(save_best_only), str(resume_training))\n", - "pdf.write_html(html)\n", - "\n", - "#pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - "pdf.set_font(\"Arial\", size = 11, style='B')\n", - "pdf.ln(1)\n", - "pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 10, style = 'B')\n", - "pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - "pdf.set_font('')\n", - "pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 10, style = 'B')\n", - "pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - "pdf.set_font('')\n", - "pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n", - "#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - "pdf.ln(1)\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 10, style = 'B')\n", - "pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - "pdf.set_font('')\n", - "pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - "pdf.ln(1)\n", - "pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n", - "pdf.ln(1)\n", - "exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n", - "pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - "pdf.ln(1)\n", - "ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - "pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - "ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - "pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "# if Use_Data_augmentation:\n", - "# ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - "# pdf.multi_cell(190, 5, txt = ref_4, align='L')\n", - "pdf.ln(3)\n", - "reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - "pdf.set_font('Arial', size = 11, style='B')\n", - "pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - "pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - "\n", - "print('------------------------------')\n", - "print('PDF report exported in '+model_path+'/'+model_name+'/')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XQjQb_J_Qyku" - }, - "source": [ - "##**4.2. Download your model from Google Drive**\n", - "\n", - "---\n", - "Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y_DtHgr-41K0", - "cellView": "form" - }, - "source": [ - "#@markdown ##Download model directory\n", - "#@markdown 1. Specify the model_path in `model_path_download`, otherwise the model specified in Section 3.1 will be downloaded\n", - "#@markdown 2. Run this cell to zip the model directory\n", - "#@markdown 3. Download the zipped file from the *Files* tab on the left\n", - "\n", - "from google.colab import files\n", - "\n", - "model_path_download = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(model_path_download) == 0:\n", - " model_path_download = full_model_path\n", - "\n", - "model_name_download = os.path.basename(model_path_download)\n", - "\n", - "print('Zipping', model_name_download)\n", - "\n", - "zip_model_path = model_name_download + '.zip'\n", - "\n", - "!zip -r \"$zip_model_path\" \"$model_path_download\"\n", - "\n", - "print('Successfully saved zipped model directory as', zip_model_path)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2HbZd7rFqAad" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "In this section, the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1 and employing more advanced metrics in Section 5.2.\n", - "\n", - "**We highly recommend performing quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "EdcnkCr9Nbl8", - "cellView": "form" - }, - "source": [ - "#@markdown ###Model to be evaluated:\n", - "#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n", - "\n", - "qc_model_name = \"\" #@param {type:\"string\"}\n", - "qc_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n", - " qc_model_name = model_name\n", - " qc_model_path = model_path\n", - "\n", - "full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n", - "\n", - "if os.path.exists(full_qc_model_path):\n", - " print(qc_model_name + ' will be evaluated')\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yDY9dtzdUTLh" - }, - "source": [ - "## **5.1. Inspecting loss function**\n", - "---\n", - "\n", - "**The training loss** is the error between prediction and target after each epoch, calculated across the training data, while the **validation loss** is the error calculated on the (unseen) validation data. During training, these values should decrease until converging, at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data, which reduces its ability to generalise. Aim to halt training before this point.\n", - "\n", - "**Note:** For a more in-depth explanation, please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n", - "\n", - "\n", - "The accuracy is another performance metric that is calculated after each epoch. 
We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy; a short standalone example of this metric is shown below. \n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Visualise loss and accuracy\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "accuracyDataFromCSV = []\n", - "valaccuracyDataFromCSV = []\n", - "\n", - "with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[2]))\n", - " vallossDataFromCSV.append(float(row[4]))\n", - " accuracyDataFromCSV.append(float(row[1]))\n", - " valaccuracyDataFromCSV.append(float(row[3]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training and validation loss', fontsize=14)\n", - "plt.ylabel('Loss', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n", - "plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n", - "plt.title('Training and validation accuracy', fontsize=14)\n", - "plt.ylabel('Dice', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RZOPCVN0qcYb" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section provides a visual indication of the model performance by displaying the source, the ground-truth target and the prediction side by side, together with an overlay of target and prediction." 
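The Dice score plotted as "accuracy" above can be reproduced on any pair of binary volumes with a few lines of NumPy. The following is a minimal standalone sketch, not part of the notebook; the commented file names are placeholders.

```python
import numpy as np

def dice_coefficient(target, prediction, eps=1e-7):
    """Sørensen–Dice coefficient between two binary volumes of identical shape."""
    target = np.asarray(target, dtype=bool)
    prediction = np.asarray(prediction, dtype=bool)
    intersection = np.logical_and(target, prediction).sum()
    # 2 * |A ∩ B| / (|A| + |B|), with eps to avoid division by zero on empty masks
    return (2.0 * intersection + eps) / (target.sum() + prediction.sum() + eps)

# Placeholder usage on a ground-truth mask and a thresholded prediction:
# gt = tifffile.imread('testing_groundtruth.tif') > 0
# pred = tifffile.imread('prediction.tif') > 127
# print('Dice:', dice_coefficient(gt, pred))
```

Dice and the IoU used in Section 5.3 are monotonically related (Dice = 2·IoU / (1 + IoU)), so ranking thresholds by either metric selects the same optimum.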
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "XbL7T9bw98Ja", - "cellView": "form" - }, - "source": [ - "#@markdown ##Compare prediction and ground-truth on testing data\n", - "\n", - "#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n", - "\n", - "testing_source = \"\" #@param{type:\"string\"}\n", - "testing_target = \"\" #@param{type:\"string\"}\n", - "\n", - "qc_dir = full_qc_model_path + '/Quality Control'\n", - "predict_dir = qc_dir + '/Prediction'\n", - "if os.path.exists(predict_dir):\n", - " shutil.rmtree(predict_dir)\n", - "\n", - "os.makedirs(predict_dir)\n", - "\n", - "# predict_dir + '/' + \n", - "predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "\n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - "\n", - "tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n", - "\n", - "print('Predicted images!')\n", - "\n", - "qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n", - "\n", - "test_target = tifffile.imread(testing_target)\n", - "test_source = tifffile.imread(testing_source)\n", - "test_prediction = tifffile.imread(predict_path)\n", - "\n", - "def scroll_in_z(z):\n", - "\n", - " plt.figure(figsize=(25,5))\n", - " # Source\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Target (Ground-truth)\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_prediction[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " \n", - " # Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='Greens')\n", - " plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lIP7AOvkg5pT" - }, - "source": [ - "## **5.3. Determine best Intersection over Union and threshold**\n", - "---\n", - "\n", - "**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! 
\n", - "\n", - "This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n", - "\n", - "The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. The IoU is calculated for the entire volume in 3D." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "1hXoooMbYvxl", - "cellView": "form" - }, - "source": [ - "\n", - "#@markdown ##Calculate Intersection over Union and best threshold \n", - "prediction = tifffile.imread(predict_path)\n", - "prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - "\n", - "target = tifffile.imread(testing_target).astype(np.bool)\n", - "\n", - "def iou_vs_threshold(prediction, target):\n", - " threshold_list = []\n", - " IoU_scores_list = []\n", - "\n", - " for threshold in range(0,256): \n", - " mask = prediction > threshold\n", - "\n", - " intersection = np.logical_and(target, mask)\n", - " union = np.logical_or(target, mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " threshold_list.append(threshold)\n", - " IoU_scores_list.append(iou_score)\n", - "\n", - " return threshold_list, IoU_scores_list\n", - "\n", - "threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n", - "thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n", - "best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n", - "best_iou = IoU_scores_list[best_thresh]\n", - "\n", - "print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n", - "\n", - "def adjust_threshold(threshold, z):\n", - "\n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,4,1)\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n", - " plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,2)\n", - " plt.imshow(target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " plt.subplot(1,4,4)\n", - " plt.title('Threshold vs. 
IoU', fontsize=15)\n", - " plt.plot(threshold_list, IoU_scores_list)\n", - " plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n", - " plt.ylabel('IoU score')\n", - " plt.xlabel('Threshold')\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n", - " plt.show()\n", - "\n", - "interact(adjust_threshold, \n", - " threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n", - " z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "\n", - "class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - "pdf = MyFPDF()\n", - "pdf.add_page()\n", - "pdf.set_right_margin(-1)\n", - "pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - "Network = 'U-Net 3D'\n", - "\n", - "day = datetime.now()\n", - "datetime_str = str(day)[0:10]\n", - "\n", - "Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n", - "pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - "all_packages = ''\n", - "for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 11, style = 'B')\n", - "pdf.ln(2)\n", - "pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - "pdf.ln(1)\n", - "if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - "else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - "pdf.ln(2)\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 10, style = 'B')\n", - "pdf.ln(3)\n", - "pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - "pdf.ln(1)\n", - "exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n", - "pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - "pdf.ln(1)\n", - "pdf.set_font('')\n", - "pdf.set_font('Arial', size = 11, style = 'B')\n", - "pdf.ln(1)\n", - "pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n", - "pdf.set_font('')\n", - "pdf.set_font_size(10.)\n", - "pdf.ln(1)\n", - "pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n", - "pdf.ln(2)\n", - "exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n", - "pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n", - "pdf.ln(1)\n", - "pdf.set_font('')\n", - "pdf.set_font_size(10.)\n", - "ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. 
\"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - "pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - "ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - "pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - "pdf.ln(3)\n", - "reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - "pdf.set_font('Arial', size = 11, style='B')\n", - "pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - "pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n", - "\n", - "print('------------------------------')\n", - "print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Esqnbew8uznk" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate predictions from unseen dataset**\n", - "---\n", - "\n", - "The most recently trained model can now be used to predict segmentation masks on unseen images. If you want to use an older model, leave `model_path` blank. Predicted output images are saved in `output_path` as Image-J compatible TIFF files.\n", - "\n", - "## **Prediction parameters**\n", - "\n", - "* **`source_path`** specifies the location of the source \n", - "image volume.\n", - "\n", - "* **`output_directory`** specified the directory where the output predictions are stored.\n", - "\n", - "* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n", - "\n", - "* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n", - "\n", - "* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n", - "\n", - "* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction will not be performed in one go to not deplete the memory resources. Instead, the prediction is iteratively performed on a subset of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32*\n", - "\n", - "* **`model_path`** specifies the path to a model other than the most recently trained." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ps4bbZgkmV8V", - "cellView": "form" - }, - "source": [ - "#@markdown ## Download example volume\n", - "\n", - "#@markdown This can take up to an hour\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "8oQr1yKyBwZS", - "cellView": "form" - }, - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n", - "\n", - "source_path = \"\" #@param {type:\"string\"}\n", - "output_directory = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(output_directory):\n", - " os.makedirs(output_directory)\n", - "\n", - "output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n", - "#@markdown ###Prediction parameters:\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "save_probability_map = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Determine best threshold in Section 5.2.\n", - "\n", - "use_calculated_threshold = True #@param {type:\"boolean\"}\n", - "threshold = 200#@param {type:\"number\"}\n", - "\n", - "# Tifffile library issues means that images cannot be appended to \n", - "#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n", - "big_tiff = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n", - "\n", - "prediction_depth = 32#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Model to be evaluated\n", - "#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n", - "\n", - "full_model_path_ = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(full_model_path_) == 0:\n", - " full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n", - "\n", - "\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "if use_calculated_threshold:\n", - " threshold = best_thresh\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "src = tifffile.imread(source_path)\n", - "\n", - "if src.nbytes >= 4e9:\n", - " big_tiff = True\n", - " print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n", - "\n", - "if binary_target:\n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - "\n", - " tifffile.imwrite(output_path, prediction, imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "if not binary_target or save_probability_map:\n", - " if not binary_target:\n", - " prob_map_path = output_path\n", - " else:\n", - " prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n", - " \n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "print('Predictions saved as', output_path)\n", - "\n", - 
"src_volume = tifffile.imread(source_path)\n", - "pred_volume = tifffile.imread(output_path)\n", - "\n", - "def scroll_in_z(z):\n", - " \n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_volume[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(pred_volume[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Rn9zpWpo0xNw" - }, - "source": [ - "\n", - "#**Thank you for using 3D U-Net!**" - ] - } - ] -} +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (3D)**\n"," ---\n","\n"," The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n","\n","**This particular implementation allows supervised learning between any two types of 3D image data. If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n","\n","This notebook is laregly based on the following paper: \n","\n","[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n","\n","The following two Python libraries play an important role in the notebook: \n","\n","1. 
[**Elasticdeform**](https://github.com/gvtulder/elasticdeform)\n"," by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n","\n","2. [**Tifffile**](https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n","\n","3. [**Imgaug**](https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. \n","\n","The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n","\n","\n","**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n","\n","**Code cells** contain code which can be modfied by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","Three tabs are located on the upper left side of the notebook:\n","\n","1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n","\n","2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n","\n","3. *Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.\n","\n","**Important:** All uploaded files are purged once the runtime ends.\n","\n","**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n","\n","To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n","You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. 
Before getting started**\n","---\n","\n","As the network operates in three dimensions, certain consideration should be given to correctly pre-processing the data. Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixelsizes are ideal for this architecture.\n","\n","Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target file do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n","\n","**Prepare two datasets** (*training* and *testing*) for quality control puproses. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisiton and sample to ensure robustness of the trained model. \n","\n","\n","---\n","\n","\n","### **Directory structure**\n","\n","Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. In case more than one training volume is available, choose the second structure.\n","\n","**Structure 1:** Only one training volume\n","```\n","path/to/directory/with/one/training/volume\n","│--training_source.tif\n","│--training_target.tif\n","| \n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","\n","```\n","**Structure 2:** Various training volumes\n","```\n","path/to/directory/with/various/training/volumes\n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","└───training\n","| └───source\n","| | |--training_volume_one.tif\n","| | |--training_volume_two.tif\n","| | |--...\n","| | |--training_volume_n.tif\n","| |\n","| └───target\n","| |--training_volume_one.tif\n","| |--training_volume_two.tif\n","| |--...\n","| |--training_volume_n.tif\n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","```\n","**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n","\n","\n","---\n","\n","\n","### **Important note**\n","\n","* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n","\n","* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n","\n","* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"M-GZMaL7pd8a"},"source":["#@markdown ##**Download example dataset**\n","\n","#@markdown This usually takes a few minutes. 
The images are saved in *example_dataset*.\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","def make_directory(dir):\n"," if not os.path.exists(dir):\n"," os.makedirs(dir)\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","\n","make_directory('example_dataset')\n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n","\n","print('Example dataset successfully downloaded!')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount Google Drive**\n","---\n"," To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n","\n","1. **Run** the **cell** below to mount your Google Drive and follow the link. \n","\n","2. **Sign in** to your Google account and press 'Allow'. \n","\n","3. Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. \n","\n","4. 
Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"zxELU7CIp4oF"},"source":["#@markdown ##Unzip pre-trained model directory\n","\n","#@markdown 1. Upload a zipped model directory using the *Files* tab\n","#@markdown 2. Run this cell to unzip your model file\n","#@markdown 3. The model directory will appear in the *Files* tab \n","\n","from google.colab import files\n","\n","zipped_model_file = \"\" #@param {type:\"string\"}\n","\n","!unzip \"$zipped_model_file\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install 3D U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Install dependencies and instantiate network\n","Notebook_version = ['1.12']\n","#Put the imported code and libraries here\n","!pip install fpdf\n","from __future__ import absolute_import, division, print_function, unicode_literals\n","\n","try:\n"," import elasticdeform\n","except:\n"," !pip install elasticdeform\n"," import elasticdeform\n","\n","try:\n"," import tifffile\n","except:\n"," !pip install tifffile\n"," import tifffile\n","\n","try:\n"," import imgaug.augmenters as iaa\n","except:\n"," !pip install imgaug\n"," import imgaug.augmenters as iaa\n","\n","import os\n","import csv\n","import random\n","import h5py\n","import imageio\n","import math\n","import shutil\n","\n","import pandas as pd\n","from glob import glob\n","from tqdm import tqdm\n","\n","from skimage import transform\n","from skimage import exposure\n","from skimage import color\n","from skimage import io\n","\n","from scipy.ndimage import zoom\n","\n","import matplotlib.pyplot as plt\n","\n","import numpy as np\n","import tensorflow as tf\n","\n","from keras import backend as K\n","\n","from keras.layers import Conv3D\n","from keras.layers import BatchNormalization\n","from keras.layers import ReLU\n","from keras.layers import MaxPooling3D\n","from keras.layers import Conv3DTranspose\n","from keras.layers import Input\n","from keras.layers import Concatenate\n","\n","from keras.models import Model\n","\n","from keras.utils import Sequence\n","\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import CSVLogger\n","from keras.callbacks import Callback\n","\n","from keras.metrics import RootMeanSquaredError\n","\n","from ipywidgets import interact\n","from ipywidgets import interactive\n","from ipywidgets import fixed\n","from ipywidgets import interact_manual \n","import ipywidgets as widgets\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","import time\n","\n","from skimage import io\n","import matplotlib\n","\n","print(\"Dependencies installed and imported.\")\n","\n","# Define MultiPageTiffGenerator class\n","class MultiPageTiffGenerator(Sequence):\n","\n"," def __init__(self,\n"," source_path,\n"," target_path,\n"," batch_size=1,\n"," shape=(128,128,32,1),\n"," augment=False,\n"," augmentations=[],\n"," deform_augment=False,\n"," 
deform_augmentation_params=(5,3,4),\n"," val_split=0.2,\n"," is_val=False,\n"," random_crop=True,\n"," downscale=1,\n"," binary_target=False):\n","\n"," # If directory with various multi-page tiffiles is provided read as list\n"," if os.path.isfile(source_path):\n"," self.dir_flag = False\n"," self.source = tifffile.imread(source_path)\n"," if binary_target:\n"," self.target = tifffile.imread(target_path).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(target_path)\n","\n"," elif os.path.isdir(source_path):\n"," self.dir_flag = True\n"," self.source_dir_list = glob(os.path.join(source_path, '*'))\n"," self.target_dir_list = glob(os.path.join(target_path, '*'))\n","\n"," self.source_dir_list.sort()\n"," self.target_dir_list.sort()\n","\n"," self.shape = shape\n"," self.batch_size = batch_size\n"," self.augment = augment\n"," self.val_split = val_split\n"," self.is_val = is_val\n"," self.random_crop = random_crop\n"," self.downscale = downscale\n"," self.binary_target = binary_target\n"," self.deform_augment = deform_augment\n"," self.on_epoch_end()\n"," \n"," if self.augment:\n"," # pass list of augmentation functions \n"," self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n"," if self.deform_augment:\n"," self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n","\n"," def __len__(self):\n"," # If various multi-page tiff files provided sum all images within each\n"," if self.augment:\n"," augment_factor = 4\n"," else:\n"," augment_factor = 1\n"," \n"," if self.dir_flag:\n"," num_of_imgs = 0\n"," for tiff_path in self.source_dir_list:\n"," num_of_imgs += tifffile.imread(tiff_path).shape[0]\n"," xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n","\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(self.val_split * num_of_imgs / self.batch_size)\n"," else:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n","\n"," else:\n"," return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n"," else:\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n"," else:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n","\n"," def __getitem__(self, idx):\n"," source_batch = np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n"," target_batch = 
np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n","\n"," for batch in range(self.batch_size):\n"," # Modulo operator ensures IndexError is avoided\n"," stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n","\n"," if self.dir_flag:\n"," self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n"," if self.binary_target:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n","\n"," src_list = []\n"," tgt_list = []\n"," for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n"," src = self.source[i]\n"," src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," src = self._min_max_scaling(src)\n"," src_list.append(src)\n","\n"," tgt = self.target[i]\n"," tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," if not self.binary_target:\n"," tgt = self._min_max_scaling(tgt)\n"," tgt_list.append(tgt)\n","\n"," if self.random_crop:\n"," if src.shape[0] == self.shape[0]:\n"," x_rand = 0\n"," if src.shape[1] == self.shape[1]:\n"," y_rand = 0\n"," if src.shape[0] > self.shape[0]:\n"," x_rand = np.random.randint(src.shape[0] - self.shape[0])\n"," if src.shape[1] > self.shape[1]:\n"," y_rand = np.random.randint(src.shape[1] - self.shape[1])\n"," if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n"," \n"," for i in range(self.shape[2]):\n"," if self.random_crop:\n"," src = src_list[i]\n"," tgt = tgt_list[i]\n"," src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," else:\n"," src_crop = src_list[i]\n"," tgt_crop = tgt_list[i]\n","\n"," source_batch[batch,:,:,i,0] = src_crop\n"," target_batch[batch,:,:,i,0] = tgt_crop\n","\n"," if self.augment:\n"," # On-the-fly data augmentation\n"," source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n","\n"," # Data augmentation by reversing stack\n"," if np.random.random() > 0.5:\n"," source_batch, target_batch = source_batch[::-1], target_batch[::-1]\n"," \n"," # Data augmentation by elastic deformation\n"," if np.random.random() > 0.5 and self.deform_augment:\n"," source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n"," \n"," if not self.binary_target:\n"," target_batch = self._min_max_scaling(target_batch)\n"," \n"," return self._min_max_scaling(source_batch), target_batch\n"," \n"," else:\n"," return source_batch, target_batch\n","\n"," def on_epoch_end(self):\n"," # Validation split performed here\n"," self.batch_list = []\n"," # Create batch_list of all combinations of tifffile and stack position\n"," if self.dir_flag:\n"," for i in range(len(self.source_dir_list)):\n"," num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," last_page = 
math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," num_of_pages = self.source.shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([0, j])\n","\n"," else:\n"," last_page = math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([0, j])\n"," \n"," if self.is_val and (len(self.batch_list) <= 0):\n"," raise ValueError('validation_split too small! Increase val_split or decrease z-depth')\n"," random.shuffle(self.batch_list)\n"," \n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n"," \n"," def class_weights(self):\n"," ones = 0\n"," pixels = 0\n","\n"," if self.dir_flag:\n"," for i in range(len(self.target_dir_list)):\n"," tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n"," ones += np.sum(tgt)\n"," pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n"," else:\n"," ones = np.sum(self.target)\n"," pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n"," p_ones = ones/pixels\n"," p_zeros = 1-p_ones\n","\n"," # Return swapped probability to increase weight of unlikely class\n"," return p_ones, p_zeros\n","\n"," def deform_volume(self, src_vol, tgt_vol):\n"," [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n"," axis=(1, 2, 3),\n"," sigma=self.deform_sigma,\n"," points=self.deform_points,\n"," order=self.deform_order)\n"," if self.binary_target:\n"," tgt_dfrm = tgt_dfrm > 0.1\n"," \n"," return self._min_max_scaling(src_dfrm), tgt_dfrm \n","\n"," def augment_volume(self, src_vol, tgt_vol):\n"," src_vol_aug = np.empty(src_vol.shape)\n"," tgt_vol_aug = np.empty(tgt_vol.shape)\n","\n"," for i in range(src_vol.shape[3]):\n"," src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n"," segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n"," return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n","\n"," def sample_augmentation(self, idx):\n"," src, tgt = self.__getitem__(idx)\n","\n"," src_aug, tgt_aug = self.augment_volume(src, tgt)\n"," \n"," if self.deform_augment:\n"," src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n","\n"," return src_aug, tgt_aug \n","\n","# Define custom loss and dice coefficient\n","def dice_coefficient(y_true, y_pred):\n"," eps = 1e-6\n"," y_true_f = K.flatten(y_true)\n"," y_pred_f = K.flatten(y_pred)\n"," intersection = K.sum(y_true_f*y_pred_f)\n","\n"," return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n","\n","def weighted_binary_crossentropy(zero_weight, one_weight):\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n","\n"," weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n"," weighted_binary_crossentropy = weight_vector*binary_crossentropy\n","\n"," return K.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","# Custom callback showing sample prediction\n","class SampleImageCallback(Callback):\n","\n"," def __init__(self, model, sample_data, model_path, save=False):\n"," self.model = model\n"," self.sample_data = sample_data\n"," self.model_path = model_path\n"," self.save = save\n","\n"," def on_epoch_end(self, epoch, logs={}):\n"," 
sample_predict = self.model.predict_on_batch(self.sample_data)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n"," plt.title('Sample source')\n"," plt.axis('off');\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', cmap='magma')\n"," plt.title('Predicted target')\n"," plt.axis('off');\n","\n"," plt.show()\n","\n"," if self.save:\n"," plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n","\n","\n","# Define Unet3D class\n","class Unet3D:\n","\n"," def __init__(self,\n"," shape=(256,256,16,1)):\n"," if isinstance(shape, str):\n"," shape = eval(shape)\n","\n"," self.shape = shape\n"," \n"," input_tensor = Input(self.shape, name='input')\n","\n"," self.model = self.unet_3D(input_tensor)\n","\n"," def down_block_3D(self, input_tensor, filters):\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def up_block_3D(self, input_tensor, concat_layer, filters):\n"," x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n","\n"," x = Concatenate()([x, concat_layer])\n","\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def unet_3D(self, input_tensor, filters=32):\n"," d1 = self.down_block_3D(input_tensor, filters=filters)\n"," p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n"," d2 = self.down_block_3D(p1, filters=filters*2)\n"," p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)\n"," d3 = self.down_block_3D(p2, filters=filters*4)\n"," p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n","\n"," d4 = self.down_block_3D(p3, filters=filters*8)\n","\n"," u1 = self.up_block_3D(d4, d3, filters=filters*4)\n"," u2 = self.up_block_3D(u1, d2, filters=filters*2)\n"," u3 = self.up_block_3D(u2, d1, filters=filters)\n","\n"," output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n","\n"," return Model(inputs=[input_tensor], outputs=[output_tensor])\n","\n"," def summary(self):\n"," return self.model.summary()\n","\n"," # Pass generators instead\n"," def train(self, \n"," epochs, \n"," batch_size, \n"," train_generator,\n"," val_generator, \n"," model_path, \n"," model_name,\n"," optimizer='adam',\n"," loss='weighted_binary_crossentropy',\n"," metrics='dice',\n"," ckpt_period=1, \n"," save_best_ckpt_only=False, \n"," ckpt_path=None):\n","\n"," class_weight_zero, class_weight_one = train_generator.class_weights()\n"," \n"," if loss == 'weighted_binary_crossentropy':\n"," loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n"," \n"," if metrics == 'dice':\n"," metrics = dice_coefficient\n","\n"," self.model.compile(optimizer=optimizer,\n"," loss=loss,\n"," metrics=[metrics])\n","\n"," if ckpt_path is not None:\n"," self.model.load_weights(ckpt_path)\n","\n"," full_model_path = os.path.join(model_path, model_name)\n","\n"," if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n"," \n"," 
log_dir = full_model_path + '/Quality Control'\n","\n"," if not os.path.exists(log_dir):\n"," os.makedirs(log_dir)\n"," \n"," ckpt_dir = full_model_path + '/ckpt'\n","\n"," if not os.path.exists(ckpt_dir):\n"," os.makedirs(ckpt_dir)\n","\n"," csv_out_name = log_dir + '/training_evaluation.csv'\n"," if ckpt_path is None:\n"," csv_logger = CSVLogger(csv_out_name)\n"," else:\n"," csv_logger = CSVLogger(csv_out_name, append=True)\n","\n"," if save_best_ckpt_only:\n"," ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n"," else:\n"," ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n"," \n"," model_ckpt = ModelCheckpoint(ckpt_name,\n"," verbose=1,\n"," period=ckpt_period,\n"," save_best_only=save_best_ckpt_only,\n"," save_weights_only=True)\n","\n"," sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n"," sample_img = SampleImageCallback(self.model, \n"," sample_batch, \n"," model_path)\n","\n"," self.model.fit_generator(generator=train_generator,\n"," validation_data=val_generator,\n"," validation_steps=math.floor(len(val_generator)/batch_size),\n"," epochs=epochs,\n"," callbacks=[csv_logger,\n"," model_ckpt,\n"," sample_img])\n","\n"," last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n"," self.model.save_weights(last_ckpt_name)\n","\n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n","\n"," def predict(self, \n"," input, \n"," ckpt_path, \n"," z_range=None, \n"," downscaling=None, \n"," true_patch_size=None):\n","\n"," self.model.load_weights(ckpt_path)\n","\n"," if isinstance(downscaling, str):\n"," downscaling = eval(downscaling)\n","\n"," if math.isnan(downscaling):\n"," downscaling = None\n","\n"," if isinstance(true_patch_size, str):\n"," true_patch_size = eval(true_patch_size)\n"," \n"," if not isinstance(true_patch_size, tuple): \n"," if math.isnan(true_patch_size):\n"," true_patch_size = None\n","\n"," if isinstance(input, str):\n"," src_volume = tifffile.imread(input)\n"," elif isinstance(input, np.ndarray):\n"," src_volume = input\n"," else:\n"," raise TypeError('Input is not path or numpy array!')\n"," \n"," in_size = src_volume.shape\n","\n"," if downscaling or true_patch_size is not None:\n"," x_scaling = 0\n"," y_scaling = 0\n","\n"," if true_patch_size is not None:\n"," x_scaling += true_patch_size[0]/self.shape[0]\n"," y_scaling += true_patch_size[1]/self.shape[1]\n"," if downscaling is not None:\n"," x_scaling += downscaling\n"," y_scaling += downscaling\n","\n"," src_list = []\n"," for i in range(src_volume.shape[0]):\n"," src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n"," src_volume = np.array(src_list) \n","\n"," if z_range is not None:\n"," src_volume = src_volume[z_range[0]:z_range[1]]\n","\n"," src_volume = self._min_max_scaling(src_volume) \n","\n"," src_array = np.zeros((1,\n"," math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n"," math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n"," math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n"," self.shape[3]))\n","\n"," for i in range(src_volume.shape[0]):\n"," src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n","\n"," pred_array = np.empty(src_array.shape)\n","\n"," for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n"," for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n"," for k in 
range(math.ceil(src_volume.shape[0]/self.shape[2])):\n"," pred_temp = self.model.predict(src_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n"," pred_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n"," \n"," pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n","\n"," if downscaling is not None:\n"," pred_list = []\n"," for i in range(pred_volume.shape[0]):\n"," pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n"," pred_volume = np.array(pred_list)\n","\n"," return pred_volume\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," if os.path.isdir(training_source):\n"," shape = io.imread(training_source+'/'+os.listdir(training_source)[0]).shape\n"," elif os.path.isfile(training_source):\n"," shape = io.imread(training_source).shape\n"," else:\n"," print('Cannot read training data.')\n","\n"," dataset_size = len(train_generator)\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if add_gaussian_blur == True:\n"," aug_text = aug_text+'\\n- gaussian blur'\n"," if add_linear_contrast == True:\n"," aug_text = aug_text+'\\n- linear contrast'\n"," if add_additive_gaussian_noise == True:\n"," aug_text = aug_text+'\\n- additive gaussian noise'\n"," if augmenters != '':\n"," aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n"," if add_elastic_deform == True:\n"," aug_text = aug_text+'\\n- elastic deformation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if use_default_advanced_parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 <table width=40%>\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","  <tr><td width = 50%>number_of_epochs</td><td width = 50%>{0}</td></tr>\n","  <tr><td width = 50%>batch_size</td><td width = 50%>{1}</td></tr>\n","  <tr><td width = 50%>patch_size</td><td width = 50%>{2}</td></tr>\n","  <tr><td width = 50%>image_pre_processing</td><td width = 50%>{3}</td></tr>\n","  <tr><td width = 50%>validation_split_in_percent</td><td width = 50%>{4}</td></tr>\n","  <tr><td width = 50%>downscaling_in_xy</td><td width = 50%>{5}</td></tr>\n","  <tr><td width = 50%>binary_target</td><td width = 50%>{6}</td></tr>\n","  <tr><td width = 50%>loss_function</td><td width = 50%>{7}</td></tr>\n","  <tr><td width = 50%>metrics</td><td width = 50%>{8}</td></tr>\n","  <tr><td width = 50%>optimizer</td><td width = 50%>{9}</td></tr>\n","  <tr><td width = 50%>checkpointing_period</td><td width = 50%>{10}</td></tr>\n","  <tr><td width = 50%>save_best_only</td><td width = 50%>{11}</td></tr>\n","  <tr><td width = 50%>resume_training</td><td width = 50%>{12}</td></tr>\n"," </table>
\n"," \"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, checkpointing_period, str(save_best_only), str(resume_training))\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 3D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(1)\n"," pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n"," pdf.ln(2)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. 
\"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net 3D and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n"," \n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":["## **3.1. Choosing parameters**\n","\n","---\n","\n","### **Paths to training data and model**\n","\n","* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n","\n","* **`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n","\n","* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n","\n","\n","**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n","\n","### **Training parameters**\n","\n","* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n","\n","* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. *Default: 1*\n","\n","* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. 
In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n","\n","* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* \n","\n","* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n","\n","* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* \n","\n","* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks *Default: True* \n","\n","* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). *Default: weighted_binary_crossentropy* \n","\n","* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n","\n","* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n","\n","**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and if the error persists at `batch_size = 1`, reduce the `patch_size`. \n","\n","**Note:** The number of steps per epoch are calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size` where `augment_factor` is three if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. 
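For example (with illustrative numbers only), 100 slices at a 20% validation split, a batch size of 1 and no augmentation give `floor(1 * 0.8 * 100 / 1) = 80` steps per epoch. 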
If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch are calculated as `floor(augment_factor * volume / (crop_volume * batch_size))` where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is defined as the volume in pixels based on the specified `patch_size`."]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training data:\n","training_source = \"\" #@param {type:\"string\"}\n","training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ---\n","\n","#@markdown ###Model name and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = os.path.join(model_path, model_name)\n","\n","#@markdown ---\n","\n","#@markdown ###Training parameters\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Default advanced parameters\n","use_default_advanced_parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown If not, please change:\n","\n","batch_size = 1#@param {type:\"number\"}\n","patch_size = (256,256,4) #@param {type:\"number\"} # in pixels\n","training_shape = patch_size + (1,)\n","image_pre_processing = 'randomly crop to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n","\n","validation_split_in_percent = 20 #@param{type:\"number\"}\n","downscaling_in_xy = 2#@param {type:\"number\"} # in pixels\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n","\n","metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n","\n","optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n","\n","\n","if image_pre_processing == \"randomly crop to patch_size\":\n"," random_crop = True\n","else:\n"," random_crop = False\n","\n","if use_default_advanced_parameters: \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," training_shape = (256,256,8,1)\n"," validation_split_in_percent = 20\n"," downscaling_in_xy = 1\n"," random_crop = True\n"," binary_target = True\n"," loss_function = 'weighted_binary_crossentropy'\n"," metrics = 'dice'\n"," optimizer = 'adam'\n","\n","#@markdown ###Checkpointing parameters\n","checkpointing_period = 1 #@param {type:\"number\"}\n","\n","#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n","save_best_only = True #@param {type:\"boolean\"}\n","\n","#@markdown ###Resume training\n","#@markdown Choose if training was interrupted:\n","resume_training = False #@param {type:\"boolean\"}\n","\n","#@markdown ###Transfer learning\n","#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n","checkpoint_path = \"\" #@param {type:\"string\"}\n","\n","if resume_training and checkpoint_path != \"\":\n"," print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n"," resume_training = False\n"," \n","\n","# Retrieve last checkpoint\n","if resume_training:\n"," try:\n"," ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort()\n"," last_ckpt_path = ckpt_dir_list[-1]\n"," print('Training will resume from checkpoint:', 
os.path.basename(last_ckpt_path))\n"," except IndexError:\n"," last_ckpt_path=None\n"," print('CheckpointError: No previous checkpoints were found, training from scratch.')\n","elif not resume_training and checkpoint_path != \"\":\n"," last_ckpt_path = checkpoint_path\n"," assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n","else:\n"," last_ckpt_path=None\n","\n","# Instantiate Unet3D \n","model = Unet3D(shape=training_shape)\n","\n","#here we check that no model with the same name already exist\n","if not resume_training and os.path.exists(full_model_path): \n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n"," # print('!! WARNING: Folder already exists and will be overwritten !!') \n"," # shutil.rmtree(full_model_path)\n","\n","# if not os.path.exists(full_model_path):\n","# os.makedirs(full_model_path)\n","\n","# Show sample image\n","if os.path.isdir(training_source):\n"," training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n"," training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n","else:\n"," training_source_sample = training_source\n"," training_target_sample = training_target\n","\n","src_sample = tifffile.imread(training_source_sample)\n","src_sample = model._min_max_scaling(src_sample)\n","if binary_target:\n"," tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n","else:\n"," tgt_sample = tifffile.imread(training_target_sample)\n","\n","src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n","tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n","\n","if random_crop:\n"," true_patch_size = None\n","\n"," if src_down.shape[0] == training_shape[0]:\n"," x_rand = 0\n"," if src_down.shape[1] == training_shape[1]:\n"," y_rand = 0\n"," if src_down.shape[0] > training_shape[0]:\n"," x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n"," if src_down.shape[1] > training_shape[1]:\n"," y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n"," if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n","else:\n"," true_patch_size = src_down.shape\n","\n","def scroll_in_z(z):\n"," src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n"," tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n"," if random_crop:\n"," src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," else:\n"," \n"," src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n"," tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_slice, cmap='gray')\n"," plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(tgt_slice, cmap='magma')\n"," plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n"," plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n"," #plt.close()\n","\n","print('This is what the training 
images will look like with the chosen settings')\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n","\n","#Create a copy of an example slice and close the display.\n","scroll_in_z(z=int(src_sample.shape[0]/2))\n","plt.close()\n","\n","# Save model parameters\n","params = {'training_source': training_source,\n","          'training_target': training_target,\n","          'model_name': model_name,\n","          'model_path': model_path,\n","          'number_of_epochs': number_of_epochs,\n","          'batch_size': batch_size,\n","          'training_shape': training_shape,\n","          'downscaling': downscaling_in_xy,\n","          'true_patch_size': true_patch_size,\n","          'val_split': validation_split_in_percent/100,\n","          'random_crop': random_crop}\n","\n","params_df = pd.DataFrame.from_dict(params, orient='index')\n","\n","# apply_data_augmentation = False\n","# pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["## **3.2. Data augmentation**\n"," \n","---\n"," Augmenting the training data increases the robustness of the model by simulating possible variations within the training data, which prevents it from overfitting on small datasets. We therefore strongly recommend augmenting the data and making sure that the applied augmentations are reasonable.\n","\n","* **Gaussian blur** blurs images using Gaussian kernels with a sigma of `gaussian_sigma`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n","\n","* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `alpha` is sampled uniformly from the interval `[contrast_min, contrast_max]` and `pixel_value` is the original value of the pixel. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n","\n","* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n","\n","* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imgaug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html).\n","In the example below, the augmentation pipeline is equivalent to: \n","```\n","seq = iaa.Sequential([\n","    iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0))),\n","    iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0))),\n","    iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)))\n","], random_order=True)\n","```\n"," Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n","\n","* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. 
This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n","\n","* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. *Default: True*"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","#@markdown ###Data augmentation\n","\n","apply_data_augmentation = False #@param {type:\"boolean\"}\n","\n","# List of augmentations\n","augmentations = []\n","\n","#@markdown ###Gaussian blur\n","add_gaussian_blur = True #@param {type:\"boolean\"}\n","gaussian_sigma = 0.7#@param {type:\"number\"}\n","gaussian_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_gaussian_blur:\n"," augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n","\n","#@markdown ###Linear contrast\n","add_linear_contrast = True #@param {type:\"boolean\"}\n","contrast_min = 0.4 #@param {type:\"number\"}\n","contrast_max = 1.6#@param {type:\"number\"}\n","contrast_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_linear_contrast:\n"," augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n","\n","#@markdown ###Additive Gaussian noise\n","add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n","scale_min = 0 #@param {type:\"number\"}\n","scale_max = 0.05 #@param {type:\"number\"}\n","noise_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_additive_gaussian_noise:\n"," augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n","\n","#@markdown ###Add custom augmenters\n","\n","augmenters = \"GammaContrast; AverageBlur; LinearContrast\" #@param {type:\"string\"}\n","\n","augmenter_params = \"(0.5, 2.0); (0.5, 2.0); (0.4, 1.6)\" #@param {type:\"string\"}\n","\n","augmenter_frequency = \"0.3; 0.4; 0.5\" #@param {type:\"string\"}\n","\n","aug_lst = augmenters.split(';')\n","aug_params_lst = augmenter_params.split(';')\n","aug_freq_lst = augmenter_frequency.split(';')\n","\n","assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n","\n","for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n"," aug, param, freq = aug.strip(), param.strip(), freq.strip() \n"," aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n"," augmentations.append(aug_func)\n","\n","#@markdown ###Elastic deformations\n","add_elastic_deform = True #@param {type:\"boolean\"}\n","sigma = 2#@param {type:\"number\"}\n","points = 2#@param {type:\"number\"}\n","order = 2#@param {type:\"number\"}\n","\n","if add_elastic_deform:\n"," deform_params = (sigma, points, order)\n","else:\n"," deform_params = None\n","\n","train_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," shape=training_shape,\n"," augment=apply_data_augmentation,\n"," augmentations=augmentations,\n"," deform_augment=add_elastic_deform,\n"," deform_augmentation_params=deform_params,\n"," val_split=validation_split_in_percent/100,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","val_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," 
shape=training_shape,\n"," val_split=validation_split_in_percent/100,\n"," is_val=True,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","\n","if apply_data_augmentation:\n"," print('Data augmentation enabled.')\n"," sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n","\n"," def scroll_in_z(z):\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n"," plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n"," plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," print('This is what the augmented training images will look like with the chosen settings')\n"," interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n","\n","else:\n"," print('Data augmentation disabled.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","\n","**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`."]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Show model and start training**\n","---\n"]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ## Show model summary\n","model.summary()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"CyQI4ssarUp4"},"source":["#@markdown ##Start training\n","\n","#here we check that no model with the same name already exist, if so delete\n","if not resume_training and os.path.exists(full_model_path): \n"," shutil.rmtree(full_model_path)\n"," print(bcolors.WARNING+'!! WARNING: Folder already exists and has been overwritten !!'+bcolors.NORMAL) \n","\n","if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n","\n","pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)\n","\n","# Save file\n","params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n","\n","start = time.time()\n","# Start Training\n","model.train(epochs=number_of_epochs,\n"," batch_size=batch_size,\n"," train_generator=train_generator,\n"," val_generator=val_generator,\n"," model_path=model_path,\n"," model_name=model_name,\n"," loss=loss_function,\n"," metrics=metrics,\n"," optimizer=optimizer,\n"," ckpt_period=checkpointing_period,\n"," save_best_ckpt_only=save_best_only,\n"," ckpt_path=last_ckpt_path)\n","\n","print('Training successfully completed!')\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["##**4.2. 
Download your model from Google Drive**\n","\n","---\n","Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Download model directory\n","#@markdown 1. Specify the model_path in `model_path_download` otherwise the model sepcified in Section 3.1 will be downloaded\n","#@markdown 2. Run this cell to zip the model directory\n","#@markdown 3. Download the zipped file from the *Files* tab on the left\n","\n","from google.colab import files\n","\n","model_path_download = \"\" #@param {type:\"string\"}\n","\n","if len(model_path_download) == 0:\n"," model_path_download = full_model_path\n","\n","model_name_download = os.path.basename(model_path_download)\n","\n","print('Zipping', model_name_download)\n","\n","zip_model_path = model_name_download + '.zip'\n","\n","!zip -r \"$zip_model_path\" \"$model_path_download\"\n","\n","print('Successfully saved zipped model directory as', zip_model_path)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","In this section the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1. and employing more advanced metrics in Section 5.2.\n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Model to be evaluated:\n","#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n","\n","qc_model_name = \"\" #@param {type:\"string\"}\n","qc_model_path = \"\" #@param {type:\"string\"}\n","\n","if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n"," qc_model_name = model_name\n"," qc_model_path = model_path\n","\n","full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n","\n","if os.path.exists(full_qc_model_path):\n"," print(qc_model_name + ' will be evaluated')\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspecting loss function**\n","---\n","\n","**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.\n","\n","**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n","\n","\n","The accuracy is another performance metric that is calculated after each epoch. 
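For a binary prediction P and target T it is computed as the overlap `2*|P ∩ T| / (|P| + |T|)`, ranging from 0 (no overlap) to 1 (a perfect match). 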
We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. \n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Visualise loss and accuracy\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","accuracyDataFromCSV = []\n","valaccuracyDataFromCSV = []\n","\n","with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[2]))\n"," vallossDataFromCSV.append(float(row[4]))\n"," accuracyDataFromCSV.append(float(row[1]))\n"," valaccuracyDataFromCSV.append(float(row[3]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training and validation loss', fontsize=14)\n","plt.ylabel('Loss', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n","plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n","plt.title('Training and validation accuracy', fontsize=14)\n","plt.ylabel('Dice', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will provide both a visual indication of the model performance by comparing the overlay of the predicted and source volume."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Compare prediction and ground-truth on testing data\n","\n","#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n","\n","testing_source = \"\" #@param{type:\"string\"}\n","testing_target = \"\" #@param{type:\"string\"}\n","\n","qc_dir = full_qc_model_path + '/Quality Control'\n","predict_dir = qc_dir + '/Prediction'\n","if os.path.exists(predict_dir):\n"," shutil.rmtree(predict_dir)\n","\n","os.makedirs(predict_dir)\n","\n","# predict_dir + '/' + \n","predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n","\n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n","\n","tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n","\n","print('Predicted images!')\n","\n","qc_metrics_path = full_qc_model_path 
+ '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n","\n","test_target = tifffile.imread(testing_target)\n","test_source = tifffile.imread(testing_source)\n","test_prediction = tifffile.imread(predict_path)\n","\n","def scroll_in_z(z):\n","\n"," plt.figure(figsize=(25,5))\n"," # Source\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Target (Ground-truth)\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_prediction[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," \n"," # Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='Greens')\n"," plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aIvRxpZlsFeZ"},"source":["## **5.3. Determine best Intersection over Union and threshold**\n","---\n","\n","**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! \n","\n","This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n","\n","The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. 
The IoU is calculated for the entire volume in 3D."]},{"cell_type":"code","metadata":{"cellView":"form","id":"XhkeZTFusHA8"},"source":["\n","#@markdown ##Calculate Intersection over Union and best threshold \n","prediction = tifffile.imread(predict_path)\n","prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n","\n","target = tifffile.imread(testing_target).astype(np.bool)\n","\n","def iou_vs_threshold(prediction, target):\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," mask = prediction > threshold\n","\n"," intersection = np.logical_and(target, mask)\n"," union = np.logical_or(target, mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return threshold_list, IoU_scores_list\n","\n","threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n","thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n","best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n","best_iou = IoU_scores_list[best_thresh]\n","\n","print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n","\n","def adjust_threshold(threshold, z):\n","\n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,4,1)\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n"," plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,2)\n"," plt.imshow(target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n","\n"," plt.subplot(1,4,4)\n"," plt.title('Threshold vs. IoU', fontsize=15)\n"," plt.plot(threshold_list, IoU_scores_list)\n"," plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n"," plt.ylabel('IoU score')\n"," plt.xlabel('Threshold')\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n"," plt.show()\n","\n","interact(adjust_threshold, \n"," threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n"," z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate predictions from unseen dataset**\n","---\n","\n","The most recently trained model can now be used to predict segmentation masks on unseen images. If you want to use an older model, leave `model_path` blank. 
Predicted output images are saved in `output_path` as Image-J compatible TIFF files.\n","\n","## **Prediction parameters**\n","\n","* **`source_path`** specifies the location of the source \n","image volume.\n","\n","* **`output_directory`** specified the directory where the output predictions are stored.\n","\n","* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n","\n","* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n","\n","* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n","\n","* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction will not be performed in one go to not deplete the memory resources. Instead, the prediction is iteratively performed on a subset of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32*\n","\n","* **`model_path`** specifies the path to a model other than the most recently trained."]},{"cell_type":"code","metadata":{"cellView":"form","id":"DEmhPh5fsWX2"},"source":["#@markdown ## Download example volume\n","\n","#@markdown This can take up to an hour\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n","\n","source_path = \"\" #@param {type:\"string\"}\n","output_directory = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(output_directory):\n"," os.makedirs(output_directory)\n","\n","output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n","#@markdown ###Prediction parameters:\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","save_probability_map = False #@param {type:\"boolean\"}\n","\n","#@markdown Determine best threshold in Section 5.2.\n","\n","use_calculated_threshold = True #@param {type:\"boolean\"}\n","threshold = 200#@param {type:\"number\"}\n","\n","# Tifffile library issues means that images cannot be appended to \n","#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n","big_tiff = False #@param {type:\"boolean\"}\n","\n","#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n","\n","prediction_depth = 32#@param {type:\"number\"}\n","\n","#@markdown ###Model to be evaluated\n","#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n","\n","full_model_path_ = \"\" #@param {type:\"string\"}\n","\n","if len(full_model_path_) == 0:\n"," full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n","\n","\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","if use_calculated_threshold:\n"," threshold = best_thresh\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","src = tifffile.imread(source_path)\n","\n","if src.nbytes >= 4e9:\n"," big_tiff = True\n"," print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n","\n","if binary_target:\n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n","\n"," tifffile.imwrite(output_path, prediction, imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","if not binary_target or save_probability_map:\n"," if not binary_target:\n"," prob_map_path = output_path\n"," else:\n"," prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n"," \n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","print('Predictions saved as', output_path)\n","\n","src_volume = tifffile.imread(source_path)\n","pred_volume = tifffile.imread(output_path)\n","\n","def scroll_in_z(z):\n"," \n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_volume[z-1], cmap='gray')\n"," 
plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(pred_volume[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using 3D U-Net!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb index 2417aae8..5a3d89b5 100644 --- a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"188vC6YW2QgihAAlmm9-z8Bs8VHNc4lSX","timestamp":1604947327316},{"file_id":"10kITYRT5xn39V5MW2H7SpBnngGOUwyNz","timestamp":1603468429406},{"file_id":"1Tb46UlrIGLb8rv74lOtpJk0sH4IQoMjB","timestamp":1602675839754},{"file_id":"1LWs9bFbYclR1nWaupcSPUYFN6yyUU_5t","timestamp":1596536407170},{"file_id":"1uUjR8Sm2l6vAJfclb84gUUH4MCwzQUWO","timestamp":1594734310956},{"file_id":"1zileODcR2RNrVSidXNuBfgFDv68JRRa0","timestamp":1593093410185},{"file_id":"1EpgWlJK6U_ZwlBGiomLfbxx9UUtRPBTy","timestamp":1592904104821},{"file_id":"1f5usS6p8Cu_efegMwcR3v68AVOXBSyIf","timestamp":1588870626184},{"file_id":"1fM7obTEQKnSgVZMDa1KjiBgiBar2b0t8","timestamp":1588693012611},{"file_id":"1owWtQQucUxUOZMaPh2x_mxe_qXKHCZhp","timestamp":1588074588514},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. 
The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. 
\n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets will give the option of saving the annotations in this format or using software for hand-annotations will automatically save the annotations in this format. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - High SNR images (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - High SNR images\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install Network and Dependencies\n","%tensorflow_version 1.x\n","!pip install pascal-voc-writer\n","!pip install fpdf\n","!pip install PTable\n","\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import sys\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess as sp\n","\n","from prettytable import from_csv\n","\n","# from matplotlib.pyplot import imread\n","\n","ia.seed(1)\n","# imgaug uses matplotlib backend 
for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","#Here, we import a different github repo which includes the map_evaluation.py\n","!git clone https://github.com/rodrigo2019/keras_yolo2.git\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2'):\n"," shutil.rmtree('/content/gdrive/My Drive/keras-yolo2')\n","\n","#Here, we import the main github repo for this notebook and move it to the gdrive\n","!git clone https://github.com/experiencor/keras-yolo2.git\n","shutil.move('/content/keras-yolo2','/content/gdrive/My Drive/keras-yolo2')\n","#Now, we move the map_evaluation.py file to the main repo for this notebook.\n","#The source repo of the map_evaluation.py can then be ignored and is not further relevant for this notebook.\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])\n","\n"," iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / 
ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," pred_boxes = pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all 
detections and all annotations\n"," average_precisions = {}\n"," F1_scores = {}\n"," total_recall = []\n"," total_precision = []\n"," \n"," with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"class\", \"false positive\", \"true positive\", \"false negative\", \"recall\", \"precision\", \"accuracy\", \"f1 score\", \"average_precision\"]) \n"," \n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," if len(precision) != 0:\n"," F1_score = 2*(precision[-1]*recall[-1]/(precision[-1]+recall[-1]))\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(int(false_positives[-1])), str(int(true_positives[-1])), str(int(num_annotations-true_positives[-1])), str(recall[-1]), str(precision[-1]), str(true_positives[-1]/num_annotations), str(F1_scores[label]), str(average_precisions[label])])\n"," else:\n"," F1_score = 0\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(0), str(0), str(0), str(0), str(0), str(0), str(F1_score), str(average_precisions[label])])\n"," return F1_scores, average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the 
boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. {0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def __init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by 
LvC\n"," # class_colours = []\n"," # for c in range(len(labels)):\n"," # colour = np.random.randint(low=0,high=255,size=3).tolist()\n"," # class_colours.append(tuple(colour))\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," #cv2.rectangle(image, (xmin,ymin), (xmax,ymax), class_colours[box.get_label()], 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n"," # #This file is to create a nicer display for the output images\n"," # if not os.path.exists('/content/predicted_bounding_boxes_display.csv'):\n"," # with open('/content/predicted_bounding_boxes_display.csv', 'w', newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new, delimiter=',')\n"," # specs_list = 
['filename','width','height','class','xmin','ymin','xmax','ymax']\n"," # csvwriter2.writerow(specs_list)\n"," # else:\n"," # with open('/content/predicted_bounding_boxes_display.csv','a+',newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new)\n"," # for box in boxes:\n"," # row = [os.path.basename(image_path),image_w,image_h,box.get_label(),int(box.xmin*image_w),int(box.ymin*image_h),int(box.xmax*image_w),int(box.ymax*image_h)]\n"," # csvwriter2.writerow(row)\n","\n","def add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My 
Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = 
set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n"," yolo.model.save(model_path+'/last_weights.h5')\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(member[4][0].text),\n"," int(member[4][1].text),\n"," int(member[4][2].text),\n"," int(member[4][3].text)\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," 
xml_df = pd.DataFrame(xml_list, columns=column_name)\n"," return xml_df\n","\n","\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it \n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","\n","print('-------------------------------------------')\n","print(\"Depencies installed and imported.\")\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m'\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\"+bcolors.NORMAL)\n","\n","\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","After playing the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"grFtuWsY5LZm"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_source_annotations`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. **Default value: 27**\n","\n","**Note that YOLOv2 uses 3 Warm-up epochs which improves the model's performance. This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** There are different backends which are available to be trained for YOLO. These are usually slightly different model architectures, with pretrained weights. Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`**Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0004**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default:1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","\n","# other parameters for training.\n","# @markdown ###Training Parameters\n","# @markdown Number of epochs:\n","\n","number_of_epochs = 27#@param {type:\"number\"}\n","\n","# !sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","# #@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 1.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","totalNobj = np.sum(df_anno[\"Nobj\"])\n","\n","\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","xs = range(len(class_count))\n","\n","\n","#Show how many objects there are in the images\n","plt.figure(figsize=(15,8))\n","plt.subplot(1,2,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"Total number of objects in the dataset: {}\".format(totalNobj))\n","plt.xlabel('Number of objects per image')\n","plt.ylabel('Occurences')\n","\n","plt.subplot(1,2,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} classes 
in total\".format(len(count)))\n","plt.show()\n","\n","\n","visualise_example = False\n","Use_pretrained_model = False\n","Use_Data_augmentation = False\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"NXxj-Xi3Kang"},"source":["#@markdown ###Play this cell to visualise an example image from your dataset to make sure annotations and images are properly matched.\n","import imageio\n","visualise_example = True\n","size = 1 \n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\n","img_dir=Training_Source\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.axis('off')\n"," plt.savefig('/content/TrainingDataExample_YOLOv2.png',bbox_inches='tight',pad_inches=0)\n"," plt.show() ## show the plot"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"eik5zLKWpN_O"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset the `Use_Data_Augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 
8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"RmTSfMO-pNMc","cellView":"form"},"source":["#@markdown ##**Augmentation Options**\n","\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 2 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[1]['width'],group_df.iloc[1]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if 
multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," 
os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n"," # display new dataframe\n"," #augmented_images_df\n"," \n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," # #Change the name of the training folder\n"," # !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," # #Change annotation folder\n"," # !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Write the annotations to a csv file\n"," #df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n"," #Show how many objects there are in the images\n"," plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in 
total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tZvcYmxTdXQm","cellView":"form"},"source":["#@markdown ###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly matched.\n","if (Use_Data_augmentation):\n"," df_anno_aug = []\n"," dir_anno_aug = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno_aug): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_aug.append(row)\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\n","\n"," size = 3 \n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\n"," img_dir=augmented_training_source\n","\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\n"," for irow in ind_random:\n"," row = df_anno_aug.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the augmented training images.')\n","\n","else:\n"," print('Data augmentation disabled.')\n","\n","# else:\n","# for irow in ind_random:\n","# row = df_anno.iloc[irow,:]\n","# path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n","# # read in image\n","# img = imageio.imread(path)\n","\n","# plt.figure(figsize=(12,12))\n","# plt.imshow(img, cmap='gray') # plot image\n","# plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n","# # for each object in the image, plot the bounding box\n","# for iplot in range(row[\"Nobj\"]):\n","# plt_rectangle(plt,\n","# label = row[\"bbx_{}_name\".format(iplot)],\n","# x1=row[\"bbx_{}_xmin\".format(iplot)],\n","# y1=row[\"bbx_{}_ymin\".format(iplot)],\n","# x2=row[\"bbx_{}_xmax\".format(iplot)],\n","# y2=row[\"bbx_{}_ymax\".format(iplot)])\n","# plt.show() ## show the plot\n","# print('These are the non-augmented training images.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ud_Sx7MT5f4_"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a YOLOv2 model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"_cvRRrStGe3y","cellView":"form"},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","# Training_Source = \"\" #@param{type:\"string\"}\n","# Training_Source_annotation = \"\" #@param{type:\"string\"}\n","# Check if the right files exist\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","# os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","# !sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","if Use_pretrained_model:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," #bestLearningRate = learning_rate\n"," #lastLearningRate = learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","else:\n"," print('No pre-trained models will be used.')\n","\n"," \n"," # !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," # !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","# with open(os.path.join(pretrained_model_path, 'Quality Control', 'lr.csv'),'r') as csvfile:\n","# csvRead = pd.read_csv(csvfile, sep=',')\n","# #print(csvRead)\n"," \n","# if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n","# print(\"pretrained network learning rate found\")\n","# #find the last learning rate\n","# lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n","# #Find the learning rate corresponding to the lowest validation loss\n","# min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n","# #print(min_val_loss)\n","# bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n","# if Weights_choice == \"last\":\n","# print('Last learning rate: '+str(lastLearningRate))\n","\n","# if Weights_choice == \"best\":\n","# print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n","# if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n","# bestLearningRate = initial_learning_rate\n","# lastLearningRate = initial_learning_rate\n","# print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.1. Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"EZnoS3rb8BSR","cellView":"form"},"source":["#@markdown ##Start training\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n"," shutil.rmtree(full_model_path)\n","\n","# Create a new directory\n","os.mkdir(full_model_path)\n","\n","# ------------\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","#os.chdir('/content/drive/My Drive/Zero-Cost Deep-Learning to Enhance Microscopy/Various dataset/Detection_Dataset_2/BCCD.v2.voc')\n","#if not os.path.exists(model_path+'/full_raccoon.h5'):\n"," # !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s\" -O full_yolo_raccoon.h5 && rm -rf /tmp/cookies.txt\n","\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/gdrive/My Drive/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' config.json\n","\n","!sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' 
config.json\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(full_model_path+'/annotations.csv', index=False)#header=False, sep=',')\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","\n","\n","#Generate anchors for the bounding boxes\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","\n","# !sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","if Use_pretrained_model:\n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","if Use_Data_augmentation:\n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n","\n","# ------------\n","\n","\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","\n","start = time.time()\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/gdrive/My Drive/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","\n","# -------------------------------------------------------------\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'YOLOv2'\n","day = datetime.now()\n","datetime_str = 
str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+'):\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = sp.run('nvcc --version',stdout=sp.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = sp.run('nvidia-smi',stdout=sp.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_Source+'/'+os.listdir(Training_Source)[1]).shape\n","dataset_size = len(os.listdir(Training_Source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(multiply_dataset_by)+' by'\n"," if multiply_dataset_by >= 2:\n"," aug_text = aug_text+'\\n- flipping'\n"," if multiply_dataset_by > 2:\n"," aug_text = aug_text+'\\n- rotation'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>train_times</td>\n","    <td width = 50%>{1}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{2}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>learning_rate</td>\n","    <td width = 50%>{3}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>false_negative_penalty</td>\n","    <td width = 50%>{4}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>false_positive_penalty</td>\n","    <td width = 50%>{5}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>position_size_penalty</td>\n","    <td width = 50%>{6}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>false_class_penalty</td>\n","    <td width = 50%>{7}</td>\n","  </tr>
\n","  <tr>\n","    <td width = 50%>percentage_validation</td>\n","    <td width = 50%>{8}</td>\n","  </tr>
\n","</table>
\n","\"\"\".format(number_of_epochs, train_times, batch_size, learning_rate, false_negative_penalty, false_positive_penalty, position_size_penalty, false_class_penalty, percentage_validation)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_Source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_Source_annotations, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","if visualise_example == True:\n"," pdf.cell(60, 5, txt = 'Example ground-truth annotation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_YOLOv2.png').shape\n"," pdf.image('/content/TrainingDataExample_YOLOv2.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","ref_3 = '- YOLOv2 keras: https://github.com/experiencor/keras-yolo2, (2018)'\n","pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","if Use_Data_augmentation:\n"," ref_4 = '- imgaug: Jung, Alexander et al., https://github.com/aleju/imgaug, (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n","print('------------------------------')\n","print('PDF report exported in '+model_path+'/'+model_name+'/')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku"},"source":["##**4.2. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","QC_model_name = os.path.basename(QC_model_folder)\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+QC_model_name+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png',bbox_inches='tight', pad_inches=0)\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RZOPCVN0qcYb"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). 
Additionally, the below cell will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .jpg)and annotations (.xml files)!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_mAP_weights`) and the model after the last epoch (`last_weights`), you should choose which ones you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"Nh8MlX3sqd_7","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," print('----')\n"," print(img)\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality Control/Prediction/\"+img)\n","\n","#Here, we open the config file to get the classes fro the GT labels\n","config_path = '/content/gdrive/My Drive/keras-yolo2/config.json'\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n","\n","#Make a csv file to read into imagej macro, to create custom bounding 
boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","F1_scores, AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","\n","\n","with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"r\") as file:\n"," x = from_csv(file)\n"," \n","print(x)\n","\n","mAP_score = sum(AP.values())/len(AP)\n","\n","print('mAP score for QC dataset: '+str(mAP_score))\n","\n","for i in range(len(AP)):\n"," if AP[i]!=0:\n"," fig = plt.figure(figsize=(8,4))\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. 
Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')\n","\n","\n","# --------------------------------------------------------------\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","print('Below is an example input, prediction and ground truth annotation from your test dataset.')\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","### Display Raw input ###\n","\n","x = plt.imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input', fontsize = 12)\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image, cmap='gray') # plot image\n"," plt.title('Prediction', fontsize=12)\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img, cmap='gray') # plot image\n","plt.title('Ground Truth annotations', fontsize=12)\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'YOLOv2'\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:16]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in 
freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","if os.path.exists(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(80, 5, txt = 'P-R curves for test dataset', ln=1, align='L')\n","pdf.ln(2)\n","for i in range(len(AP)):\n"," if os.path.exists(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png').shape\n"," pdf.ln(1)\n"," pdf.image(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', x=16, y=None, w=round(exp_size[1]/4), h=round(exp_size[0]/4))\n"," else:\n"," pdf.cell(100, 5, txt='For the class '+config['model']['labels'][i]+' the model did not predict any objects.', ln=1, align='L')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = \"\"\"\n","\n","\n","\"\"\"\n","with open(QC_model_folder+'/Quality Control/QC_results.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," class_name = header[0]\n"," fp = header[1]\n"," tp = header[2]\n"," fn = header[3]\n"," recall = header[4]\n"," precision = header[5]\n"," acc = header[6]\n"," f1 = header[7]\n"," AP_score = header[8]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,recall,precision,acc,f1,AP_score)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," class_name = row[0]\n"," fp = row[1]\n"," tp = row[2]\n"," fn = row[3]\n"," recall = row[4]\n"," precision = row[5]\n"," acc = row[6]\n"," f1 = row[7]\n"," AP_score = row[8]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,str(round(float(recall),3)),str(round(float(precision),3)),str(round(float(acc),3)),str(round(float(f1),3)),str(round(float(AP_score),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th><th>{8}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td><td>{8}</td></tr>
\"\"\"\n","\n","pdf.write_html(html)\n","pdf.cell(180, 5, txt='Mean average precision (mAP) over the all classes is: '+str(round(mAP_score,3)), ln=1, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(3)\n","exp_size = io.imread(QC_model_folder+'/Quality Control/QC_example_data.png').shape\n","pdf.image(QC_model_folder+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.ln(3)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","ref_3 = '- YOLOv2 keras: https://github.com/experiencor/keras-yolo2, (2018)'\n","pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(QC_model_folder+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","print('------------------------------')\n","print('PDF report exported in '+QC_model_folder+'/Quality Control/')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
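Before playing the cell below, it can help to quickly sanity-check the folders it will use. The snippet below is only an illustrative sketch with placeholder paths; point the two variables at your own folders (they correspond to the `Data_folder` and `Result_folder` asked for in the cell).

```python
# Illustrative pre-flight check of the prediction folders (placeholder paths).
import os

Data_folder = '/content/gdrive/My Drive/Experiment_A/To_predict'   # images to process
Result_folder = '/content/gdrive/My Drive/Experiment_A/Results'    # where outputs will go

images = [f for f in os.listdir(Data_folder) if not f.startswith('.')]
extensions = {os.path.splitext(f)[1].lower() for f in images}

print('{} file(s) found in Data_folder'.format(len(images)))
print('File extensions present:', extensions)

# The prediction cell infers the file extension from the first file it finds,
# so a folder mixing several file types (or containing stray non-image files)
# is worth cleaning up first.
if len(extensions) > 1:
    print('Warning: more than one file type detected - keep only .png or .jpg images.')

os.makedirs(Result_folder, exist_ok=True)   # make sure the output folder exists
```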
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"9ZmST3JRq-Ho","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(Prediction_model_path+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images will be saved into folder:\", Result_folder)\n","\n","\n","# ----- Predictions ------\n","\n","start = time.time()\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import random\n","from matplotlib.pyplot import imread\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","print(random_choice)\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\n","\n","plt.figure(figsize=(20,8))\n","\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input')\n","\n","plt.subplot(1,3,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," plt.title('Alternative Display of Prediction')\n"," plt.imshow(image, cmap='gray') # plot image\n","\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n"," #plt.margins(0,0)\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\n"," #plt.gca().xaxis.set_major_locator(plt.NullLocator())\n"," #plt.gca().yaxis.set_major_locator(plt.NullLocator())\n"," plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\n","plt.show() ## show the plot\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
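One convenient way to do this is to zip the Results folder directly from the notebook and download the archive in a single click. This is only a suggestion (it assumes `Result_folder` from section 6.1 is still defined, and the archive name is arbitrary); downloading the folder from the Google Drive web interface works just as well.

```python
# Optional: zip the Results folder and download the archive (sketch only).
import shutil
from google.colab import files

archive_path = shutil.make_archive('/content/YOLOv2_results', 'zip', Result_folder)
files.download(archive_path)   # triggers a browser download of the zip file
```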
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw"},"source":["\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610968154980},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. 
Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets will give the option of saving the annotations in this format or using software for hand-annotations will automatically save the annotations in this format. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - High SNR images (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - High SNR images\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. 
Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install Network and Dependencies\n","%tensorflow_version 1.x\n","!pip install pascal-voc-writer\n","!pip install fpdf\n","!pip install PTable\n","\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import sys\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess as sp\n","\n","from prettytable import from_csv\n","\n","# from matplotlib.pyplot import imread\n","\n","ia.seed(1)\n","# imgaug uses matplotlib backend for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","#Here, we import a different github repo which includes the map_evaluation.py\n","!git clone https://github.com/rodrigo2019/keras_yolo2.git\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2'):\n"," shutil.rmtree('/content/gdrive/My Drive/keras-yolo2')\n","\n","#Here, we import the main github repo for this notebook and move it to the gdrive\n","!git clone https://github.com/experiencor/keras-yolo2.git\n","shutil.move('/content/keras-yolo2','/content/gdrive/My Drive/keras-yolo2')\n","#Now, we move the map_evaluation.py file to the main repo for this notebook.\n","#The source repo of the map_evaluation.py can then be ignored and is not further relevant for this 
notebook.\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])\n","\n"," iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = 
np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," pred_boxes = pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all detections and all annotations\n"," average_precisions = {}\n"," F1_scores = {}\n"," total_recall = []\n"," total_precision = []\n"," \n"," with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"class\", \"false positive\", \"true positive\", \"false negative\", \"recall\", \"precision\", \"accuracy\", \"f1 score\", \"average_precision\"]) \n"," \n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," 
detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," if len(precision) != 0:\n"," F1_score = 2*(precision[-1]*recall[-1]/(precision[-1]+recall[-1]))\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(int(false_positives[-1])), str(int(true_positives[-1])), str(int(num_annotations-true_positives[-1])), str(recall[-1]), str(precision[-1]), str(true_positives[-1]/num_annotations), str(F1_scores[label]), str(average_precisions[label])])\n"," else:\n"," F1_score = 0\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(0), str(0), str(0), str(0), str(0), str(0), str(F1_score), str(average_precisions[label])])\n"," return F1_scores, average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. 
{0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def __init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by LvC\n"," # class_colours = []\n"," # for c in range(len(labels)):\n"," # colour = np.random.randint(low=0,high=255,size=3).tolist()\n"," # class_colours.append(tuple(colour))\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), 
(xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," #cv2.rectangle(image, (xmin,ymin), (xmax,ymax), class_colours[box.get_label()], 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n"," # #This file is to create a nicer display for the output images\n"," # if not os.path.exists('/content/predicted_bounding_boxes_display.csv'):\n"," # with open('/content/predicted_bounding_boxes_display.csv', 'w', newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new, delimiter=',')\n"," # specs_list = ['filename','width','height','class','xmin','ymin','xmax','ymax']\n"," # csvwriter2.writerow(specs_list)\n"," # else:\n"," # with open('/content/predicted_bounding_boxes_display.csv','a+',newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new)\n"," # for box in boxes:\n"," # row = [os.path.basename(image_path),image_w,image_h,box.get_label(),int(box.xmin*image_w),int(box.ymin*image_h),int(box.xmax*image_w),int(box.ymax*image_h)]\n"," # csvwriter2.writerow(row)\n","\n","def 
add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", 
\"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. 
Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
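Loss values come from yolo.model.history, and the per-epoch mAP column is read back from the mAP.csv file written by the patched MapEvaluation callback.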
\n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n"," yolo.model.save(model_path+'/last_weights.h5')\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(member[4][0].text),\n"," int(member[4][1].text),\n"," int(member[4][2].text),\n"," int(member[4][3].text)\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," 
xml_df = pd.DataFrame(xml_list, columns=column_name)\n"," return xml_df\n","\n","\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it \n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","\n","print('-------------------------------------------')\n","print(\"Depencies installed and imported.\")\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m'\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\"+bcolors.NORMAL)\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+'):\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = sp.run('nvcc --version',stdout=sp.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = sp.run('nvidia-smi',stdout=sp.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_Source+'/'+os.listdir(Training_Source)[1]).shape\n"," dataset_size = len(os.listdir(Training_Source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(multiply_dataset_by)+' by'\n"," if multiply_dataset_by >= 2:\n"," aug_text = aug_text+'\\n- flipping'\n"," if multiply_dataset_by > 2:\n"," aug_text = aug_text+'\\n- rotation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40%>\n","  <tr>\n","    <th align=\"left\">Parameter</th>\n","    <th align=\"left\">Value</th>\n","  </tr>\n","  <tr><td>number_of_epochs</td><td>{0}</td></tr>\n","  <tr><td>train_times</td><td>{1}</td></tr>\n","  <tr><td>batch_size</td><td>{2}</td></tr>\n","  <tr><td>learning_rate</td><td>{3}</td></tr>\n","  <tr><td>false_negative_penalty</td><td>{4}</td></tr>\n","  <tr><td>false_positive_penalty</td><td>{5}</td></tr>\n","  <tr><td>position_size_penalty</td><td>{6}</td></tr>\n","  <tr><td>false_class_penalty</td><td>{7}</td></tr>\n","  <tr><td>percentage_validation</td><td>{8}</td></tr>\n","</table>
\n"," \"\"\".format(number_of_epochs, train_times, batch_size, learning_rate, false_negative_penalty, false_positive_penalty, position_size_penalty, false_class_penalty, percentage_validation)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source_annotations, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," if visualise_example == True:\n"," pdf.cell(60, 5, txt = 'Example ground-truth annotation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_YOLOv2.png').shape\n"," pdf.image('/content/TrainingDataExample_YOLOv2.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 
2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," if augmentation:\n"," ref_4 = '- imgaug: Jung, Alexander et al., https://github.com/aleju/imgaug, (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(80, 5, txt = 'P-R curves for test dataset', ln=1, align='L')\n"," pdf.ln(2)\n"," for i in range(len(AP)):\n"," if os.path.exists(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png').shape\n"," pdf.ln(1)\n"," pdf.image(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', x=16, y=None, w=round(exp_size[1]/4), h=round(exp_size[0]/4))\n"," else:\n"," pdf.cell(100, 5, txt='For the class '+config['model']['labels'][i]+' the model did not predict any objects.', ln=1, align='L')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(QC_model_folder+'/Quality Control/QC_results.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," class_name = header[0]\n"," fp = header[1]\n"," tp = header[2]\n"," fn = header[3]\n"," recall = header[4]\n"," precision = header[5]\n"," acc = header[6]\n"," f1 = header[7]\n"," 
AP_score = header[8]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,recall,precision,acc,f1,AP_score)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," class_name = row[0]\n"," fp = row[1]\n"," tp = row[2]\n"," fn = row[3]\n"," recall = row[4]\n"," precision = row[5]\n"," acc = row[6]\n"," f1 = row[7]\n"," AP_score = row[8]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,str(round(float(recall),3)),str(round(float(precision),3)),str(round(float(acc),3)),str(round(float(f1),3)),str(round(float(AP_score),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>
\"\"\"\n","\n"," pdf.write_html(html)\n"," pdf.cell(180, 5, txt='Mean average precision (mAP) over the all classes is: '+str(round(mAP_score,3)), ln=1, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(3)\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(3)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(QC_model_folder+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+QC_model_folder+'/Quality Control/')\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","After playing the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_source_annotations`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. 
**Default value: 27**\n","\n","**Note that YOLOv2 uses 3 Warm-up epochs which improves the model's performance. This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** There are different backends which are available to be trained for YOLO. These are usually slightly different model architectures, with pretrained weights. Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`**Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0004**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default:1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","\n","# other parameters for training.\n","# @markdown ###Training Parameters\n","# @markdown Number of epochs:\n","\n","number_of_epochs = 27#@param {type:\"number\"}\n","\n","# !sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","# #@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 1.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," 
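# NB: enabling the defaults overrides the form values above; in particular batch_size becomes 8 here rather than the 16 shown in the form\n","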
false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","totalNobj = np.sum(df_anno[\"Nobj\"])\n","\n","\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","xs = range(len(class_count))\n","\n","\n","#Show how many objects there are in the images\n","plt.figure(figsize=(15,8))\n","plt.subplot(1,2,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"Total number of objects in the dataset: {}\".format(totalNobj))\n","plt.xlabel('Number of objects per image')\n","plt.ylabel('Occurences')\n","\n","plt.subplot(1,2,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} classes in total\".format(len(count)))\n","plt.show()\n","\n","visualise_example = False\n","Use_pretrained_model = False\n","Use_Data_augmentation = False\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n"," shutil.rmtree(full_model_path)\n","\n","# Create a new directory\n","os.mkdir(full_model_path)\n","\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"0JPIts19QBBz"},"source":["#@markdown ###Play this cell to visualise an example image from your dataset to make sure annotations and images are properly matched.\r\n","import imageio\r\n","visualise_example = True\r\n","size = 1 \r\n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\r\n","img_dir=Training_Source\r\n","\r\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\r\n","for irow in ind_random:\r\n"," row = df_anno.iloc[irow,:]\r\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\r\n"," # read in image\r\n"," img = imageio.imread(path)\r\n","\r\n"," plt.figure(figsize=(12,12))\r\n"," plt.imshow(img, cmap='gray') # plot image\r\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\r\n"," # for each object in the image, plot the bounding box\r\n"," for iplot in range(row[\"Nobj\"]):\r\n"," plt_rectangle(plt,\r\n"," label = row[\"bbx_{}_name\".format(iplot)],\r\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\r\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\r\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\r\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\r\n"," plt.axis('off')\r\n"," plt.savefig('/content/TrainingDataExample_YOLOv2.png',bbox_inches='tight',pad_inches=0)\r\n"," plt.show() ## show the 
plot\r\n","\r\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset the `Use_Data_Augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation Options**\n","\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 2 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," 
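# the combined table holds the original annotations plus the flipped copies, with bounding-box coordinates already transformed by imgaug\n"," # print(len(labels_df), 'original +', len(augmented_images_df), 'augmented rows')  # optional sanity check\n","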
all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[1]['width'],group_df.iloc[1]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if 
file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n"," # display new dataframe\n"," #augmented_images_df\n"," \n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," # #Change the name of the training folder\n"," # !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," # #Change annotation folder\n"," # !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Write the annotations to a csv file\n"," #df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n"," #Show how many objects there are in the images\n"," 
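# (these counts are recomputed from the augmented annotation folder, so they include the flipped and rotated copies)\n","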
plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"y7HVvJZuNU1t"},"source":["#@markdown ###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly matched.\r\n","if (Use_Data_augmentation):\r\n"," df_anno_aug = []\r\n"," dir_anno_aug = augmented_training_source_annotation\r\n"," for fnm in os.listdir(dir_anno_aug): \r\n"," if not fnm.startswith('.'): ## do not include hidden folders/files\r\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\r\n"," row = extract_single_xml_file(tree)\r\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\r\n"," df_anno_aug.append(row)\r\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\r\n","\r\n"," size = 3 \r\n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\r\n"," img_dir=augmented_training_source\r\n","\r\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\r\n"," for irow in ind_random:\r\n"," row = df_anno_aug.iloc[irow,:]\r\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\r\n"," # read in image\r\n"," img = imageio.imread(path)\r\n","\r\n"," plt.figure(figsize=(12,12))\r\n"," plt.imshow(img, cmap='gray') # plot image\r\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\r\n"," # for each object in the image, plot the bounding box\r\n"," for iplot in range(row[\"Nobj\"]):\r\n"," plt_rectangle(plt,\r\n"," label = row[\"bbx_{}_name\".format(iplot)],\r\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\r\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\r\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\r\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\r\n"," plt.show() ## show the plot\r\n"," print('These are the augmented training images.')\r\n","\r\n","else:\r\n"," print('Data augmentation disabled.')\r\n","\r\n","# else:\r\n","# for irow in ind_random:\r\n","# row = df_anno.iloc[irow,:]\r\n","# path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\r\n","# # read in image\r\n","# img = imageio.imread(path)\r\n","\r\n","# plt.figure(figsize=(12,12))\r\n","# plt.imshow(img, cmap='gray') # plot image\r\n","# plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\r\n","# # for each object in the image, plot the bounding box\r\n","# for iplot in range(row[\"Nobj\"]):\r\n","# 
plt_rectangle(plt,\r\n","# label = row[\"bbx_{}_name\".format(iplot)],\r\n","# x1=row[\"bbx_{}_xmin\".format(iplot)],\r\n","# y1=row[\"bbx_{}_ymin\".format(iplot)],\r\n","# x2=row[\"bbx_{}_xmax\".format(iplot)],\r\n","# y2=row[\"bbx_{}_ymax\".format(iplot)])\r\n","# plt.show() ## show the plot\r\n","# print('These are the non-augmented training images.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a YOLOv2 model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","# Training_Source = \"\" #@param{type:\"string\"}\n","# Training_Source_annotation = \"\" #@param{type:\"string\"}\n","# Check if the right files exist\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","# os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","# !sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","if Use_pretrained_model:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," #bestLearningRate = learning_rate\n"," #lastLearningRate = learning_rate\n"," 
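# note: bestLearningRate and the colour-reset code W are not defined on this branch because the fallback assignments above are commented out\n","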
print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","else:\n"," print('No pre-trained models will be used.')\n","\n"," \n"," # !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," # !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","# with open(os.path.join(pretrained_model_path, 'Quality Control', 'lr.csv'),'r') as csvfile:\n","# csvRead = pd.read_csv(csvfile, sep=',')\n","# #print(csvRead)\n"," \n","# if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n","# print(\"pretrained network learning rate found\")\n","# #find the last learning rate\n","# lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n","# #Find the learning rate corresponding to the lowest validation loss\n","# min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n","# #print(min_val_loss)\n","# bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n","# if Weights_choice == \"last\":\n","# print('Last learning rate: '+str(lastLearningRate))\n","\n","# if Weights_choice == \"best\":\n","# print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n","# if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n","# bestLearningRate = initial_learning_rate\n","# lastLearningRate = initial_learning_rate\n","# print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.1. Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# full_model_path = os.path.join(model_path,model_name)\n","# if os.path.exists(full_model_path):\n","# print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n","# shutil.rmtree(full_model_path)\n","\n","# # Create a new directory\n","# os.mkdir(full_model_path)\n","\n","# ------------\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","#os.chdir('/content/drive/My Drive/Zero-Cost Deep-Learning to Enhance Microscopy/Various dataset/Detection_Dataset_2/BCCD.v2.voc')\n","#if not os.path.exists(model_path+'/full_raccoon.h5'):\n"," # !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s\" -O full_yolo_raccoon.h5 && rm -rf /tmp/cookies.txt\n","\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/gdrive/My Drive/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' config.json\n","\n","!sed -i 
's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(full_model_path+'/annotations.csv', index=False)#header=False, sep=',')\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","\n","\n","#Generate anchors for the bounding boxes\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","\n","# !sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","if Use_pretrained_model:\n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","if Use_Data_augmentation:\n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n","\n","# ------------\n","\n","\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","\n","start = time.time()\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/gdrive/My Drive/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. 
Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","QC_model_name = os.path.basename(QC_model_folder)\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+QC_model_name+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png',bbox_inches='tight', pad_inches=0)\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). 
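As a complement to the loss plots in section 5.1 above, a minimal sketch (not from the original notebook) that summarises the same `Quality Control/training_evaluation.csv` file; the column order (training loss, validation loss, mAP) is assumed to match what the plotting cell reads as `row[0]`, `row[1]` and `row[2]`:

```python
import pandas as pd

def summarise_training(qc_model_folder):
    # Columns assumed, in order: training loss, validation loss, mAP score.
    df = pd.read_csv(qc_model_folder + '/Quality Control/training_evaluation.csv')
    val_loss = df.iloc[:, 1]
    map_score = df.iloc[:, 2]
    best_val_epoch = int(val_loss.idxmin())
    best_map_epoch = int(map_score.idxmax())
    # Epoch numbers start at 0, matching the x-axis of the plots above.
    print('Lowest validation loss: epoch %d (%.4f)' % (best_val_epoch, val_loss[best_val_epoch]))
    print('Highest mAP: epoch %d (%.4f)' % (best_map_epoch, map_score[best_map_epoch]))
```

A validation loss that keeps rising after its minimum while the training loss keeps falling is the overfitting pattern described in section 5.1.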
Additionally, the below cell will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .jpg)and annotations (.xml files)!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_mAP_weights`) and the model after the last epoch (`last_weights`), you should choose which ones you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," print('----')\n"," print(img)\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality Control/Prediction/\"+img)\n","\n","#Here, we open the config file to get the classes fro the GT labels\n","config_path = '/content/gdrive/My Drive/keras-yolo2/config.json'\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n","\n","#Make a csv file to read into imagej macro, to create custom bounding 
boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","F1_scores, AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","\n","\n","with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"r\") as file:\n"," x = from_csv(file)\n"," \n","print(x)\n","\n","mAP_score = sum(AP.values())/len(AP)\n","\n","print('mAP score for QC dataset: '+str(mAP_score))\n","\n","for i in range(len(AP)):\n"," if AP[i]!=0:\n"," fig = plt.figure(figsize=(8,4))\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. 
Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')\n","\n","\n","# --------------------------------------------------------------\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","print('Below is an example input, prediction and ground truth annotation from your test dataset.')\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","### Display Raw input ###\n","\n","x = plt.imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input', fontsize = 12)\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image, cmap='gray') # plot image\n"," plt.title('Prediction', fontsize=12)\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img, cmap='gray') # plot image\n","plt.title('Ground Truth annotations', fontsize=12)\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. 
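As a supplement to the mAP, Precision and Recall definitions in section 5.2 above, a minimal sketch of how an average-precision value can be obtained from one class's precision-recall curve. This is illustrative only and is not the notebook's own `_calc_avg_precisions` implementation:

```python
import numpy as np

def average_precision(recall, precision):
    # recall / precision: 1-D arrays for one class, sorted by increasing recall.
    r = np.concatenate(([0.0], np.asarray(recall, dtype=float), [1.0]))
    p = np.concatenate(([0.0], np.asarray(precision, dtype=float), [0.0]))
    # Build the monotonically non-increasing precision envelope.
    for i in range(len(p) - 2, -1, -1):
        p[i] = max(p[i], p[i + 1])
    # Integrate precision over the recall steps (area under the envelope).
    idx = np.where(r[1:] != r[:-1])[0]
    return float(np.sum((r[idx + 1] - r[idx]) * p[idx + 1]))

# mAP is then the mean of the per-class AP values, e.g.:
# mAP = np.mean([average_precision(rec_c, prec_c) for rec_c, prec_c in per_class_curves])
```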
After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," 
shutil.copyfile(Prediction_model_path+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images will be saved into folder:\", Result_folder)\n","\n","\n","# ----- Predictions ------\n","\n","start = time.time()\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. 
Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tP1isF0PO4C1"},"source":["## **6.2. Inspect the predicted output**\r\n","---\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ypLeYWnzO6tv"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\r\n","import random\r\n","from matplotlib.pyplot import imread\r\n","# This will display a randomly chosen dataset input and predicted output\r\n","random_choice = random.choice(os.listdir(Data_folder))\r\n","print(random_choice)\r\n","x = imread(Data_folder+\"/\"+random_choice)\r\n","\r\n","os.chdir(Result_folder)\r\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\r\n","\r\n","plt.figure(figsize=(20,8))\r\n","\r\n","plt.subplot(1,3,1)\r\n","plt.axis('off')\r\n","plt.imshow(x, interpolation='nearest', cmap='gray')\r\n","plt.title('Input')\r\n","\r\n","plt.subplot(1,3,2)\r\n","plt.axis('off')\r\n","plt.imshow(y, interpolation='nearest')\r\n","plt.title('Predicted output');\r\n","\r\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\r\n","\r\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\r\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\r\n","for img in range(0,df_bbox2.shape[0]):\r\n"," df_bbox2.iloc[img]\r\n"," row = pd.DataFrame(df_bbox2.iloc[img])\r\n"," if row[img][0] == random_choice:\r\n"," row = row.dropna()\r\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\r\n"," #plt.figure(figsize=(12,12))\r\n"," plt.subplot(1,3,3)\r\n"," plt.axis('off')\r\n"," plt.title('Alternative Display of Prediction')\r\n"," plt.imshow(image, cmap='gray') # plot image\r\n","\r\n"," for i in range(1,int(len(row)-1),6):\r\n"," plt_rectangle(plt,\r\n"," label = row[img][i+5],\r\n"," x1=row[img][i],#.format(iplot)],\r\n"," y1=row[img][i+1],\r\n"," x2=row[img][i+2],\r\n"," y2=row[img][i+3])#,\r\n"," #fontsize=8)\r\n"," #plt.margins(0,0)\r\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\r\n"," #plt.gca().xaxis.set_major_locator(plt.NullLocator())\r\n"," 
#plt.gca().yaxis.set_major_locator(plt.NullLocator())\r\n"," plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\r\n","plt.show() ## show the plot\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\r\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb index 7a518a87..8dce4096 100644 --- a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1-7F_2XEGwhGpwF5yRHgPDeFfrxYmLVVW","timestamp":1602528354120},{"file_id":"1-rjE9xp9Jrjkti3DT_bEDEhxx3kgvhEb","timestamp":1597394632737},{"file_id":"1G6lQzjd259Yoy_OozBhJolF4HraE52PG","timestamp":1591353884724},{"file_id":"1pSC680miQesRinU8Tjn7X6AmJXNtxUNI","timestamp":1591182507229},{"file_id":"1ajYZgvhQfpcUZ5YWlUeB-GUU_j-njsqw","timestamp":1589209398121},{"file_id":"1QiFrHg_cVlOl_yzu-RO9mMIrA2L1dXwj","timestamp":1587744376035},{"file_id":"1_S3UtNcuAaZhVc4yqlFDHc2eKq1x-ynn","timestamp":1587058075616},{"file_id":"1Gce_llcAX7yJTFZP2HiNpTL56gXR7PQ-","timestamp":1586854238074},{"file_id":"10l0NA5VWlqRvDlJRTxOiOUgN5LxEo2gy","timestamp":1586601464429},{"file_id":"1NSdad2BEDJZ16AO3SEEaG-ZSe0o4u3eY","timestamp":1586368373257},{"file_id":"1ubiSLYW3G4eNGNF31e2Vbw_3jMHJ9Y7M","timestamp":1585303720184},{"file_id":"1O6YzESEk9VFr6Nc6ijOAYCtiP80uuh7I","timestamp":1585248652537},{"file_id":"1DPrSIbf-ML-LIO2e4YhL1KedWVsVcFlT","timestamp":1585232236512},{"file_id":"1Qanbeybd44tHmdzKxTJAMDD4trFdCYwD","timestamp":1585049767771},{"file_id":"1Fr9Ea5QdUgK0CKfQKpq9KrxtxxAkSVwc","timestamp":1584619265981},{"file_id":"1RQ6XuOBIRaWgId2WKO2i-MMnXoKn_tNA","timestamp":1584541702239},{"file_id":"1mAvQKCCelwK8zPkAWFvKtiAsE_35KSpW","timestamp":1584533728194},{"file_id":"1LdMzIh-v-gUXnd6v9U2Ov28T-XpeT1PP","timestamp":1584463518766},{"file_id":"18Y0NabtThelB0uOAJlg7UbjHPYMEoCqW","timestamp":1584455459923},{"file_id":"1ZCnLW6HUl0bXrPa-54-bv_C9f6jYL0T4","timestamp":1584436296801},{"file_id":"1gTLXTd_rOpXmlktZz2yeEW62gY8ety-I","timestamp":1583941948440},{"file_id":"1gC_pmaDD73tD-yNoFGjHEolYfLd_7czL","timestamp":1583593255888},{"file_id":"17pZee2Vp0kCh3W8pfzRYk8asqk35mOfw","timestamp":1583335080677},{"file_id":"1KyYm3JglQpPYnf-aBLLiP-sFgi_A0Og1","timestamp":1583291424450},{"file_id":"1ZJCI2p66noTaLCnVUQJkTR16ig6GAqAx","timestamp":1576151149296}],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"C-wdtVN5KUFi"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). 
The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"Qt5Yt1vsD163"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"zwILBhMkzKp_"},"source":["#**0. Before getting started**\n","---\n","\n"," This notebook provides two opportunities: firstly, to download and train Fnet with data published in the original manuscript or secondly, to upload a personal dataset and train Fnet on it.\n"," The notebook may require a large amount of disk space. If using the datasets from the paper, the available disk space on the user's google drive should contain at least 40GB."]},{"cell_type":"markdown","metadata":{"id":"pcNfrIVpNZC-"},"source":["---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"I0aF5U_Y0IFW"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"EBHobPtQ8wx7"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"UphYcwdDS8yO","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kRVmtCZB9OQ2"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"QTEFQc6j9RTv","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yk96o-_u-27d"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"BbYpGlfskzrO","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","!pip install fpdf\n","\n","#@markdown ##Play this cell to download fnet to your drive. 
If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","import datetime\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","#!pip install -U scipy==1.2.0\n","#!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if 
np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqCe6m-C_PrH"},"source":["#**3. Select your paths and parameters**\n","---"]},{"cell_type":"markdown","metadata":{"id":"w5NmDpJ4xvWE"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**"]},{"cell_type":"code","metadata":{"id":"PWxNzzgKu9Kb","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","from astropy.visualization import simple_norm\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. 
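Before copying the dataset, a quick sanity check can catch the two data requirements stated above: matching file names between the signal and target folders, and stacks with at least 32 Z slices or a power-of-two slice count. This is a sketch, not part of the original cell; it assumes the Z axis is the first dimension of each stack, as this cell itself does further down (`Image_Z = x.shape[0]`):

```python
import os
from skimage import io

def check_fnet_dataset(training_source, training_target):
    source_files = set(os.listdir(training_source))
    target_files = set(os.listdir(training_target))
    if source_files != target_files:
        print('WARNING: files without a counterpart:', sorted(source_files ^ target_files))
    for f in sorted(source_files & target_files):
        n_z = io.imread(os.path.join(training_source, f)).shape[0]
        # fnet expects >= 32 Z slices, or a power-of-two number of slices.
        if n_z < 32 and (n_z & (n_z - 1)) != 0:
            print('WARNING:', f, 'has', n_z, 'Z slices; expected >= 32 or a power of two.')

# Example use with the folders selected above:
# check_fnet_dataset(Training_source, Training_target)
```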
If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,round(len(source)*(percentage_validation/100)))\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)\n","\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCKcSJxkxi33"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"msrTTcPI1Cav"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images)."]},{"cell_type":"code","metadata":{"id":"u_YFN6Bd594L","cellView":"form"},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," 
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_source'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_source') \n"," os.mkdir(Saving_path+'/augmented_validation_source')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_target'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_target') \n"," os.mkdir(Saving_path+'/augmented_validation_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target', flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target')\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," 
source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Fetch the path and extract the name of the Validation source folder\n"," Validation_source = Saving_path+'/augmented_validation_source'\n"," Validation_target = Saving_path+'/augmented_validation_target'\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Validation_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.copytree(Validation_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Define Validation file lists\n"," val_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," val_target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. 
for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(source)>110:\n"," number_of_images = 110\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"heuBzM5JADYf"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"eLllOs_rA62U"},"source":["##**4.1. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3)."]},{"cell_type":"code","metadata":{"id":"hYGcj_XmT9GY","cellView":"form"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. 
If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","number_of_images = 50#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xe3TLu7M-3Dk","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","#Overwriting old models and saving them separately if True\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","print('Let''s start the training!')\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","# if os.path.exists(model_path+'/'+model_name):\n","# shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Label-free Prediction (fnet)'\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', 
'\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","#text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n","#if Use_pretrained_model:\n","# text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. 
The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","# if Use_Default_Advanced_Parameters:\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table><tr><th>Parameter</th><th>Value</th></tr>
<tr><td>percentage_validation</td><td>{0}</td></tr>
<tr><td>steps</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr></table>
\n","\"\"\".format(percentage_validation,steps,batch_size)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n","pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fpXr4JlCd5uV"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"id":"x41OhmO-hsX3","cellView":"form"},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. 
Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QefQX9WUBz0G"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"2-0m_-tF9oo-","cellView":"form"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before,\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","##CHECK THIS Prediction_model_name\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv', 'r') as f:\n","#with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My 
Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 10000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. 
Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vH3EzxbfD6Uk","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##4.2. Start re-training model\n","!pip install tifffile==2019.7.26\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","min, sec = divmod(dt, 60) \n","hour, min = divmod(min, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Label-free Prediction (fnet)'\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 
'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","# if Use_Default_Advanced_Parameters:\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table><tr><th>Parameter</th><th>Value</th></tr>
<tr><td>percentage_validation</td><td>{0}</td></tr>
<tr><td>steps</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr></table>
\n","\"\"\".format(percentage_validation,steps,batch_size)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(model_path+'/TrainingDataExample_Fnet.png').shape\n","pdf.image(model_path+'/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Label-free Prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(Prediction_model_folder+'/'+Prediction_model_name+'_'+date_time+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jwORXPtcqRHZ"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"rVBx2b2MpoFf","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"/content/gdrive/My Drive/Fnet_Models/Fnet_pdf_16\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNR6bAk6oZJD"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ratRdSDlcQ9G","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YkhOGv3Hp2xI"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"vqSH6EQb4BwU","cellView":"form"},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," # if Use_the_current_trained_model == True:\n"," # writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," # else:\n"," # writer.writerow([\"/content/gdrive/My 
Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. 
GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," 
NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Label-free prediction (fnet)'\n","#model_name = os.path.basename(QC_model_folder)\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n","pdf.ln(3)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","html = 
\"\"\"\n","\n","\n","\"\"\"\n","with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td></tr>
\"\"\"\n"," \n","pdf.write_html(html)\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"V2ghLobACMy6"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"SMw0nWXeeC1N"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"8yoXStc8Lo27","cellView":"form"},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = False #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","# We also allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if 
Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","start = time.time()\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e2f-coEkCf58"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"Uzv5rp6LrYQF","cellView":"form"},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Select the slice would you like to view?\n","slice_number = 1#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3dP2CrCVee1m"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) 
if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"IXXOocFl3on8"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"id":"emO85anSThPJ","cellView":"form"},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l52zLRCn3z9v"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. 
To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","!pip install fpdf\n","\n","#@markdown ##Play this cell to download fnet to your drive. If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","#!pip install -U scipy==1.2.0\n","#!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," 
remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning 
messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>percentage_validation</td><td>{0}</td></tr>
<tr><td>steps</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Exporting requirements.txt for local run\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. 
**Default: 4**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","from astropy.visualization import simple_norm\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. 
If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,round(len(source)*(percentage_validation/100)))\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)\n","\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images)."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," 
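# save the three remaining flipped, rotated target stacks so every augmented source file has a matching target\n"," 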
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_source'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_source') \n"," os.mkdir(Saving_path+'/augmented_validation_source')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_target'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_target') \n"," os.mkdir(Saving_path+'/augmented_validation_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target', flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target')\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," 
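# From here on the augmented folders replace the original training folders, so the csv creation and copy steps below run on the augmented data unchanged\n"," 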
source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Fetch the path and extract the name of the Validation source folder\n"," Validation_source = Saving_path+'/augmented_validation_source'\n"," Validation_target = Saving_path+'/augmented_validation_target'\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Validation_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.copytree(Validation_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Define Validation file lists\n"," val_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," val_target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. 
for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(source)>110:\n"," number_of_images = 110\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\r\n","---\r\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\r\n","\r\n","\r\n","###**Choose one of the options to train fnet**.\r\n","\r\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\r\n","\r\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\r\n","\r\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.1. Start Training**\r\n","---\r\n","\r\n","####Play the cell below to start training. \r\n","\r\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\r\n","\r\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"X8YHeSGr76je"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. 
If this happens, reduce the number of images to be loaded into the buffer and restart the training.\r\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\r\n","number_of_images = 40#@param{type:\"number\"}\r\n","!chmod u+x train_model.sh\r\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"7Ofm-71T8ABX"},"source":["#@markdown ##Start training\r\n","pdf_export(augmentation = Use_Data_augmentation)\r\n","start = time.time()\r\n","\r\n","#Overwriting old models and saving them separately if True\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\r\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\r\n","\r\n","#This tifffile release runs error-free in this version of fnet.\r\n","!pip install tifffile==2019.7.26\r\n","\r\n","#Here we import an additional module to the functions.py file to run it without errors.\r\n","\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\r\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\r\n","\r\n","if os.path.exists(model_path+'/'+model_name):\r\n"," shutil.rmtree(model_path+'/'+model_name)\r\n"," \r\n","print('Let''s start the training!')\r\n","#Here we start the training\r\n","!./scripts/train_model.sh $model_name 0\r\n","\r\n","#After training overwrite any existing model in the model_path with the new trained model.\r\n","# if os.path.exists(model_path+'/'+model_name):\r\n","# shutil.rmtree(model_path+'/'+model_name)\r\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\r\n","\r\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\r\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\r\n","\r\n","\r\n","# Displaying the time elapsed for training\r\n","dt = time.time() - start\r\n","mins, sec = divmod(dt, 60) \r\n","hour, mins = divmod(mins, 60) \r\n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\r\n","\r\n","#Create a pdf document with training summary\r\n","\r\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. 
**If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"aWJxOy-R8OhH"},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\r\n","\r\n","import shutil\r\n","import os\r\n","model_name = \"\" #@param {type:\"string\"}\r\n","model_path = \"\" #@param {type:\"string\"}\r\n","\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\r\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\r\n","else:\r\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\r\n","---\r\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\r\n","\r\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"iDIgosht8U7F"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\r\n","#@markdown Enter the paths of the datasets you want to continue training on.\r\n","\r\n","#Here we replace values in the old files\r\n","\r\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\r\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\r\n","\r\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\r\n","#Clear the White space from train.sh\r\n","\r\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\r\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\r\n"," for line in inFile:\r\n"," if line.strip():\r\n"," outFile.write(line)\r\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\r\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\r\n","\r\n","#Datasets\r\n","\r\n","#Change checkpoints\r\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\r\n","\r\n","#Adapt Class Dataset for Tiff files\r\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\r\n","\r\n","\r\n","Training_source = \"\" #@param {type: \"string\"}\r\n","source_name = os.path.basename(os.path.normpath(Training_source))\r\n","\r\n","#Fetch the path and extract the name of the signal folder\r\n","Training_target = \"\" #@param {type: \"string\"}\r\n","target_name = os.path.basename(os.path.normpath(Training_target))\r\n","\r\n","Pretrained_model_folder = \"\" 
#@param{type:\"string\"}\r\n","#model_name = \"\" #@param {type:\"string\"}\r\n","\r\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\r\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\r\n","batch_size = 4 #@param {type:\"number\"}\r\n","\r\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\r\n","\r\n","#Move your model to fnet\r\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\r\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\r\n","\r\n","#Move the datasets into fnet\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\r\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\r\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\r\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\r\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\r\n","\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\r\n","\r\n","### number_of_images = len(os.listdir(Training_source)) ###\r\n","\r\n","#Change the train_model.sh file to include chosen dataset\r\n","!chmod u+x ./train_model.sh\r\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\r\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\r\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\r\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\r\n","\r\n","\r\n","# We will use the same validation files from the training dataset as used before,\r\n","# This makes sure that the model is not validated with files it has seen in training before saving.\r\n","\r\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\r\n","val_source_list = []\r\n","\r\n","##CHECK THIS Prediction_model_name\r\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\r\n"," shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\r\n","\r\n","with open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv', 'r') as f:\r\n","#with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\r\n"," contents = csv.reader(f,delimiter=',')\r\n"," for row in contents:\r\n"," val_source_list.append(row[0])\r\n","\r\n","#Get the file list without the header\r\n","val_source_list = val_source_list[1::]\r\n","\r\n","#Get only the file names and not the full path\r\n","for i in range(0,len(val_source_list)):\r\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\r\n","\r\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\r\n","\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\r\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\r\n","if os.path.exists('/content/gdrive/My 
Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\r\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\r\n","\r\n","#Make validation directories\r\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\r\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\r\n","\r\n","#Move a random set of files from the training to the validation folders\r\n","for file in val_source_list:\r\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\r\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\r\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\r\n","\r\n","#Redefine the source and target lists after moving the validation files\r\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\r\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\r\n","\r\n","#Define Validation file lists\r\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\r\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\r\n","\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\r\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\r\n","\r\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\r\n","\r\n","#Make a training csv file.\r\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\r\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\r\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\r\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\r\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\r\n"," writer = csv.writer(file)\r\n"," writer.writerow([\"path_signal\",\"path_target\"])\r\n"," for i in range(0,len(source)):\r\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\r\n","\r\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\r\n","\r\n","#Find the number of previous training iterations (steps) from loss csv file\r\n","\r\n","with open(Pretrained_model_folder+'/losses.csv') as f:\r\n"," previous_steps = sum(1 for line in f)\r\n","print('continuing training after step '+str(previous_steps-1))\r\n","\r\n","print('To start re-training play section 4.2. below')\r\n","\r\n","#@markdown For how many additional steps do you want to train the model?\r\n","add_steps = 10000#@param {type:\"number\"}\r\n","\r\n","#Calculate the new number of total training epochs. 
Subtract 1 to discount the title row of the csv file.\r\n","new_steps = previous_steps + add_steps -1\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\r\n","\r\n","#Edit train_model.sh file to include new total number of training epochs\r\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\r\n","\r\n","#Load one randomly chosen training source file\r\n","random_choice=random.choice(os.listdir(Training_source))\r\n","x = io.imread(Training_source+\"/\"+random_choice)\r\n","\r\n","#Find image Z dimension and select the mid-plane\r\n","Image_Z = x.shape[0]\r\n","mid_plane = int(Image_Z / 2)+1\r\n","\r\n","os.chdir(Training_target)\r\n","y = io.imread(Training_target+\"/\"+random_choice)\r\n","\r\n","f=plt.figure(figsize=(16,8))\r\n","plt.subplot(1,2,1)\r\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\r\n","plt.axis('off')\r\n","plt.title('Training Source (single Z plane)');\r\n","plt.subplot(1,2,2)\r\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\r\n","plt.axis('off')\r\n","plt.title('Training Target (single Z plane)');\r\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"5IXdFqhM8gO2"},"source":["start = time.time()\r\n","\r\n","#@markdown ##4.2. Start re-training model\r\n","!pip install tifffile==2019.7.26\r\n","\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\r\n","\r\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\r\n","\r\n","#Here we retrain the model on the chosen dataset.\r\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\r\n","!chmod u+x ./scripts/train_model.sh\r\n","!./scripts/train_model.sh $Pretrained_model_name 0\r\n","\r\n","if os.path.exists(Pretrained_model_folder):\r\n"," shutil.rmtree(Pretrained_model_folder)\r\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\r\n","\r\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\r\n","\r\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\r\n","# Displaying the time elapsed for training\r\n","dt = time.time() - start\r\n","min, sec = divmod(dt, 60) \r\n","hour, min = divmod(min, 60) \r\n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")\r\n","\r\n","#Create a pdf document with training summary\r\n","\r\n","# save FPDF() class into a \r\n","# variable pdf \r\n","from datetime import datetime\r\n","\r\n","class MyFPDF(FPDF, HTMLMixin):\r\n"," pass\r\n","\r\n","pdf = MyFPDF()\r\n","pdf.add_page()\r\n","pdf.set_right_margin(-1)\r\n","pdf.set_font(\"Arial\", size = 11, style='B') \r\n","\r\n","Network = 'Label-free Prediction (fnet)'\r\n","day = datetime.now()\r\n","date_time = str(day)[0:10]\r\n","\r\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\r\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \r\n"," \r\n","# add another cell \r\n","training_time = \"Training time: \"+str(hour)+ \"hour(s) 
\"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\r\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\r\n","pdf.ln(1)\r\n","\r\n","Header_2 = 'Information for your materials and methods:'\r\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\r\n","\r\n","all_packages = ''\r\n","for requirement in freeze(local_only=True):\r\n"," all_packages = all_packages+requirement+', '\r\n","#print(all_packages)\r\n","\r\n","#Main Packages\r\n","main_packages = ''\r\n","version_numbers = []\r\n","for name in ['tensorflow','numpy','torch','scipy']:\r\n"," find_name=all_packages.find(name)\r\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\r\n"," #Version numbers only here:\r\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\r\n","\r\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\r\n","cuda_version = cuda_version.stdout.decode('utf-8')\r\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\r\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\r\n","gpu_name = gpu_name.stdout.decode('utf-8')\r\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\r\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\r\n","#print(gpu_name)\r\n","\r\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\r\n","dataset_size = len(os.listdir(Training_source))\r\n","\r\n","text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\r\n","\r\n","pdf.set_font('')\r\n","pdf.set_font_size(10.)\r\n","pdf.multi_cell(190, 5, txt = text, align='L')\r\n","pdf.set_font('')\r\n","pdf.set_font('Arial', size = 10, style = 'B')\r\n","pdf.ln(1)\r\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\r\n","pdf.set_font('')\r\n","if Use_Data_augmentation:\r\n"," aug_text = 'The dataset was augmented by'\r\n"," if Rotation:\r\n"," aug_text = aug_text+'\\n- rotation'\r\n"," if Flip:\r\n"," aug_text = aug_text+'\\n- flipping'\r\n","else:\r\n"," aug_text = 'No augmentation was used for training.'\r\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\r\n","pdf.set_font('Arial', size = 11, style = 'B')\r\n","pdf.ln(1)\r\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\r\n","pdf.set_font('')\r\n","pdf.set_font_size(10.)\r\n","# if Use_Default_Advanced_Parameters:\r\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\r\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\r\n","pdf.ln(1)\r\n","html = \"\"\" \r\n","\r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n"," \r\n","
<table width=60%>
<tr><th width = 50% align=left>Parameter</th><th width = 50% align=left>Value</th></tr>
<tr><td width = 50%>percentage_validation</td><td width = 50%>{0}</td></tr>
<tr><td width = 50%>steps</td><td width = 50%>{1}</td></tr>
<tr><td width = 50%>batch_size</td><td width = 50%>{2}</td></tr>
</table>
\r\n","\"\"\".format(percentage_validation,steps,batch_size)\r\n","pdf.write_html(html)\r\n","\r\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\r\n","pdf.set_font(\"Arial\", size = 11, style='B')\r\n","pdf.ln(1)\r\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\r\n","pdf.set_font('')\r\n","pdf.set_font('Arial', size = 10, style = 'B')\r\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\r\n","pdf.set_font('')\r\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\r\n","pdf.set_font('')\r\n","pdf.set_font('Arial', size = 10, style = 'B')\r\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\r\n","pdf.set_font('')\r\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\r\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\r\n","pdf.ln(1)\r\n","pdf.set_font('')\r\n","pdf.set_font('Arial', size = 10, style = 'B')\r\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\r\n","pdf.set_font('')\r\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\r\n","pdf.ln(1)\r\n","pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\r\n","pdf.ln(1)\r\n","exp_size = io.imread(model_path+'/TrainingDataExample_Fnet.png').shape\r\n","pdf.image(model_path+'/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\r\n","pdf.ln(1)\r\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\r\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\r\n","ref_2 = '- Label-free Prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\r\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\r\n","pdf.ln(3)\r\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\r\n","pdf.set_font('Arial', size = 11, style='B')\r\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\r\n","\r\n","pdf.output(Prediction_model_folder+'/'+Prediction_model_name+'_'+date_time+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Source_QC_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," # if Use_the_current_trained_model == True:\n"," # writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," # else:\n"," # writer.writerow([\"/content/gdrive/My 
Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. 
GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," 
NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
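If you hit this error, it can help to first check which GPU the session was given and how much memory it has before deciding how much to crop or downsample. The cell below is a small optional check, not part of the original notebook, that simply queries `nvidia-smi` directly:

```python
#@markdown ##Optional: play this cell to see which GPU is allocated and how much memory is free
# Not part of the original notebook: a quick query to judge whether the images
# are likely to fit in GPU memory or should be cropped / downsampled first.
!nvidia-smi --query-gpu=name,memory.total,memory.used,memory.free --format=csv
```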
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","# We also allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if 
Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","start = time.time()\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!/content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\r\n","---\r\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"cellView":"form","id":"66-af3rO9vM4"},"source":["!pip install matplotlib==2.2.3\r\n","import numpy as np\r\n","import matplotlib.pyplot as plt\r\n","from skimage import io\r\n","import os\r\n","from ipywidgets import interact\r\n","import ipywidgets as widgets\r\n","\r\n","#@markdown ###Select the slice would you like to view?\r\n","slice_number = 1#@param {type:\"number\"}\r\n","\r\n","def show_image(file=os.listdir(Data_folder)):\r\n"," os.chdir(Results_folder)\r\n","\r\n","#source_image = io.imread(test_signal[0])\r\n"," source_image = io.imread(os.path.join(Data_folder,file))\r\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\r\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\r\n","\r\n","#Create the figure\r\n"," fig = plt.figure(figsize=(10,20))\r\n","\r\n"," #Setting up colours\r\n"," cmap = plt.cm.Greys\r\n","\r\n"," plt.subplot(1,2,1)\r\n"," print(prediction_image.shape)\r\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\r\n"," plt.title('Source')\r\n"," plt.subplot(1,2,2)\r\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\r\n"," plt.title('Prediction')\r\n","\r\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. 
Download your predictions**\r\n","---\r\n","\r\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"RoiTamQC9_Pr"},"source":["## **6.4. Purge unnecessary folders**\r\n","---\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"3VStzQ0k-FUm"},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\r\n","\r\n","import shutil\r\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb index 811244c7..f6bef164 100644 --- a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pix2pix_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1oHT9zNqc_2AxhL3etEPmYAHSIZwWz91k","timestamp":1602598878693},{"file_id":"1mqcexfPBaIWuvMWWbJZUFtPoZoJJwrEA","timestamp":1589278334507},{"file_id":"159ARwlQE7-zi0EHxunOF_YPFLt-ZVU5x","timestamp":1587562499898},{"file_id":"1W-7NHehG5MRFILvZZzhPWWnOdJMkadb2","timestamp":1586332290412},{"file_id":"1pUetEQICxYWkYVaQIgdRH1EZBTl7oc2A","timestamp":1586292199692},{"file_id":"1MD36ZkM6XR9EuV12zimJmfCjzyeYZFWq","timestamp":1586269469061},{"file_id":"16A2mbaHzlEElntS8qkFBOsBvZG-mUeY6","timestamp":1586253795726},{"file_id":"1gJlcjOiSxr2buDOxmcFbT_d-GqwLjXtK","timestamp":1583343225796},{"file_id":"10yGI51WzHfgWgZAyE-EbkZFEvIOd6CP6","timestamp":1583171396283}],"collapsed_sections":[],"toc_visible":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **pix2pix**\n","\n","---\n","\n","pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n","\n"," **This particular notebook enables image-to-image translation learned from paired dataset. If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n","\n","The source code of the PyTorch implementation of pix2pix can be found here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"N3azwKB9O0oW"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"id":"ByW6Vqdn9sYV","cellView":"form"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. 
Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .PNG files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["# **1. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"BCPhV-pe-syw"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"VNZetvLiS1qV","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UBrnApIUBgxv"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. 
Install pix2pix and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.11']\n","\n","\n","#@markdown ##Install pix2pix and dependencies\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","import glob\n","import os.path\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('----------------------------')\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"BLmBseWbRvxL"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). 
Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** pix2pix divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"pIrTwJjzwV-D","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","#To use pix2pix we need to organise the data in a way the network can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","imageA_folder = Saving_path+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","TrainA_Folder = Saving_path+\"/A/train\"\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/B/train\"\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is at least bigger than 256\n","if patch_size < 256:\n"," patch_size = 256\n"," print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","y = imageio.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_pix2pix.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"5LEowmfAWqPs"},"source":["## **3.2. Data augmentation**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"Flz3qoQrWv0v"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"OsIBK-sywkfy","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 10 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# 
Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"v-leE8pEWRkn"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. 
\n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"CbOcS3wiWV9w","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n"," \n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"-A4ipz8gs3Ew"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from Section 3 to prepare the training data into a suitable format for training. **Your data will be copied in the google Colab \"content\" folder which may take some time depending on the size of your dataset.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"_V2ujGB60gDv","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source_dir)):\n"," shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","print('Copying training target data...')\n","for f in tqdm(os.listdir(Training_target_dir)):\n"," shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n","\n","#---------------------------------------------------------------------\n","\n","#--------------- Here we combined A and B images---------\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","\n","# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose 
half and half\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","print('------------------------')\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**"]},{"cell_type":"code","metadata":{"id":"eBD50tAgv5qf","cellView":"form"},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change pix2pix paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," 
#('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","\n","if Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n","\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'pix2pix'\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 'Information for your materials and method:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," 
version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style="margin-left:0px;"><tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>initial_learning_rate</td><td>{3}</td></tr></table>
\n","\"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread('/content/TrainingDataExample_pix2pix.png').shape\n","pdf.image('/content/TrainingDataExample_pix2pix.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","if Use_Data_augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XQjQb_J_Qyku"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"2HbZd7rFqAad"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"NEBRRG8QyEDG"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"EdcnkCr9Nbl8","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ry9qN2tlydXq"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"1yauWCc78HKD"},"source":[" Pix2pix save model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
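As a minimal standalone sketch (not the notebook's own QC code; file names are placeholders and a single-channel, [0, 1]-normalised pair is assumed), these metrics can be computed for one ground-truth/prediction pair with scikit-image as follows:

```python
# Hedged illustration of the QC metrics described above (mSSIM/SSIM map, RSE map, NRMSE, PSNR).
import numpy as np
from skimage import io
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

gt = io.imread('ground_truth.png').astype(np.float32) / 255.0      # placeholder file
pred = io.imread('prediction.png').astype(np.float32) / 255.0      # placeholder file

# mSSIM and per-pixel SSIM map (Gaussian-weighted window, sigma = 1.5, as in the prose above)
mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                        gaussian_weights=True, sigma=1.5,
                                        use_sample_covariance=False)

# Root Squared Error map, and NRMSE taken as the square root of its mean
rse_map = np.sqrt(np.square(gt - pred))
nrmse = np.sqrt(np.mean(rse_map))

# Peak signal-to-noise ratio in decibels
psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
print(f"mSSIM: {mssim:.3f}, NRMSE: {nrmse:.3f}, PSNR: {psnr:.2f} dB")
```
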
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"2nBPucJdK3KS","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","import glob\n","import os.path\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","\n","# Create a quality control/Prediction Folder\n","\n","QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n","\n","if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n","os.makedirs(QC_prediction_results)\n","\n","# Here we count how many images are in our folder to be predicted and we had a few\n","Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","imageA_folder = Saving_path_QC_folder+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_QC_folder+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_QC_folder+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n","os.makedirs(testAB_folder)\n","\n","testA_Folder = Saving_path_QC_folder+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_QC_folder+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","QC_checkpoint_folders = \"/content/\"+QC_model_name\n","\n","if os.path.exists(QC_checkpoint_folders):\n"," shutil.rmtree(QC_checkpoint_folders)\n","os.makedirs(QC_checkpoint_folders)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","\n","for files in os.listdir(Target_QC_folder):\n"," shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","#Here we create a merged folder containing only imageA\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = 
random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","patch_size_QC = Image_min_dim\n","\n","if not patch_size_QC % 256 == 0:\n"," patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n"," print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n","\n","if patch_size_QC < 256:\n"," patch_size_QC = 256\n","\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n","\n","print(Nb_Checkpoint)\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n"," os.chdir(\"/content\")\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n"," #-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," \n"," test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, 
eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 
255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Threshold plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\n","\n","#Setting up colours\n"," cmap = None\n","\n","\n"," plt.figure(figsize=(15,15))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n"," \n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(20,20))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","#Make a pdf summary of the QC results\n","\n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'pix2pix'\n","\n","\n","day = datetime.now()\n","datetime_str = str(day)[0:10]\n","\n","Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(2)\n","pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png').shape\n","pdf.image(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(2)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(3)\n","pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n","if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n","if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","\n","pdf.ln(1)\n","for checkpoint in os.listdir(full_QC_model_path+'/Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)) and checkpoint != 'Prediction':\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}
{0}{1}{2}
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","pdf.ln(3)\n","reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Esqnbew8uznk"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. 
To use the \"latest\" checkpoint, input \"latest\".\n"]},{"cell_type":"code","metadata":{"id":"yb3suNkfpNA9","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","import glob\n","import os.path\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n","\n","\n"," \n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","\n","imageA_folder = Saving_path_prediction+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_prediction+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_prediction+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_Folder = Saving_path_prediction+\"/AB/test\"\n","os.makedirs(testAB_Folder)\n","\n","testA_Folder = Saving_path_prediction+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_prediction+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n"," shutil.copyfile(Data_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","# Here we create a merged A / A image for the prediction\n","os.chdir(\"/content\")\n","!python 
pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# Here we count how many images are in our folder to be predicted and we had a few\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","Checkpoint_name = \"test_\"+str(checkpoint)\n","\n","\n","Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n","Prediction_results_images = os.listdir(Prediction_results_folder)\n","\n","for f in Prediction_results_images: \n"," if (f.endswith(\"_real_B.png\")): \n"," os.remove(Prediction_results_folder+\"/\"+f)\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EIe3CRD7XUxa"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"LmDP8xiwXTTL","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Rn9zpWpo0xNw"},"source":["\n","#**Thank you for using pix2pix!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pix2pix_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610978553958},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **pix2pix**\n","\n","---\n","\n","pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n","\n"," **This particular notebook enables image-to-image translation learned from paired dataset. 
If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n","\n","The source code of the PyTorch implementation of pix2pix can be found here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"W7HfryEazzJE"},"source":["# **License**\r\n","\r\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"4TTFT14b0J6n"},"source":["#@markdown ##Double click to see the license information\r\n","\r\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\r\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\r\n","\r\n","\r\n","\r\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\r\n","\r\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\r\n","#All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without\r\n","#modification, are permitted provided that the following conditions are met:\r\n","\r\n","#* Redistributions of source code must retain the above copyright notice, this\r\n","# list of conditions and the following disclaimer.\r\n","\r\n","#* Redistributions in binary form must reproduce the above copyright notice,\r\n","# this list of conditions and the following disclaimer in the documentation\r\n","# and/or other materials provided with the distribution.\r\n","\r\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\r\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\r\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\r\n","#DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\r\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\r\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\r\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\r\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\r\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r\n","\r\n","\r\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\r\n","#BSD License\r\n","\r\n","#For pix2pix software\r\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\r\n","#All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without\r\n","#modification, are permitted provided that the following conditions are met:\r\n","\r\n","#* Redistributions of source code must retain the above copyright notice, this\r\n","# list of conditions and the following disclaimer.\r\n","\r\n","#* Redistributions in binary form must reproduce the above copyright notice,\r\n","# this list of conditions and the following disclaimer in the documentation\r\n","# and/or other materials provided with the distribution.\r\n","\r\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\r\n","#BSD License\r\n","\r\n","#For dcgan.torch software\r\n","\r\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\r\n","\r\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\r\n","\r\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\r\n","\r\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\r\n","\r\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\r\n","\r\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. 
It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .PNG files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. 
Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install pix2pix and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install pix2pix and dependencies\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","import glob\n","import os.path\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('----------------------------')\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is 
up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'pix2pix'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=60% style="margin-left:0px;">
<tr>
<th width = 50% align="left">Parameter</th>
<th width = 50% align="left">Value</th>
</tr>
<tr><td width = 50%>number_of_epochs</td><td width = 50%>{0}</td></tr>
<tr><td width = 50%>patch_size</td><td width = 50%>{1}</td></tr>
<tr><td width = 50%>batch_size</td><td width = 50%>{2}</td></tr>
<tr><td width = 50%>initial_learning_rate</td><td width = 50%>{3}</td></tr>
</table>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_pix2pix.png').shape\n"," pdf.image('/content/TrainingDataExample_pix2pix.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'pix2pix'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'/Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)) and checkpoint != 'Prediction':\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","!pip freeze > requirements.txt\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** pix2pix divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","#To use pix2pix we need to organise the data in a way the network can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","imageA_folder = Saving_path+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","TrainA_Folder = Saving_path+\"/A/train\"\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/B/train\"\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is at least 
bigger than 256\n","if patch_size < 256:\n"," patch_size = 256\n"," print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","y = imageio.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_pix2pix.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by [Augmentor.](https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 10 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," 
augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n"," \n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from Section 3 to prepare the training data into a suitable format for training. 
**Your data will be copied in the google Colab \"content\" folder which may take some time depending on the size of your dataset.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source_dir)):\n"," shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","print('Copying training target data...')\n","for f in tqdm(os.listdir(Training_target_dir)):\n"," shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n","\n","#---------------------------------------------------------------------\n","\n","#--------------- Here we combined A and B images---------\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","\n","# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose half and half\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","#Export of pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print('------------------------')\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change pix2pix paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","\n","if Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n","\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Export pdf summary after training to update document\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"HQqBkYzT4hQS"},"source":["## **5.1. Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kittWWbs4pc8"},"source":["## **5.2. 
Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"SeGNGf4A4ukf"},"source":[" Pix2pix saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\r\n","\r\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. Metrics used include:\r\n","\r\n","**1. The SSIM (structural similarity) map** \r\n","\r\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \r\n","\r\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\r\n","\r\n","**The output below shows the SSIM maps with the mSSIM**\r\n","\r\n","**2. The RSE (Root Squared Error) map** \r\n","\r\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\r\n","\r\n","\r\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\r\n","\r\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\r\n","\r\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"VfF_oMpI4-Xl"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\r\n","\r\n","import glob\r\n","import os.path\r\n","\r\n","\r\n","Source_QC_folder = \"\" #@param{type:\"string\"}\r\n","Target_QC_folder = \"\" #@param{type:\"string\"}\r\n","\r\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\r\n","\r\n","\r\n","# average function\r\n","def Average(lst): \r\n"," return sum(lst) / len(lst) \r\n","\r\n","# Create a quality control folder\r\n","\r\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\r\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\r\n","\r\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\r\n","\r\n","\r\n","# Create a quality control/Prediction Folder\r\n","\r\n","QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\r\n","\r\n","if os.path.exists(QC_prediction_results):\r\n"," shutil.rmtree(QC_prediction_results)\r\n","\r\n","os.makedirs(QC_prediction_results)\r\n","\r\n","# Here we count how many images are in our folder to be predicted and we had a few\r\n","Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\r\n","\r\n","# List images in Source_QC_folder\r\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \r\n","random_choice = random.choice(os.listdir(Source_QC_folder))\r\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n","#Find image XY dimension\r\n","Image_Y = x.shape[0]\r\n","Image_X = x.shape[1]\r\n","\r\n","Image_min_dim = min(Image_Y, Image_X)\r\n","\r\n","# Here we need to move the data to be analysed so that pix2pix can find them\r\n","\r\n","Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\r\n","\r\n","if os.path.exists(Saving_path_QC):\r\n"," shutil.rmtree(Saving_path_QC)\r\n","os.makedirs(Saving_path_QC)\r\n","\r\n","Saving_path_QC_folder = Saving_path_QC+\"/QC\"\r\n","\r\n","if os.path.exists(Saving_path_QC_folder):\r\n"," shutil.rmtree(Saving_path_QC_folder)\r\n","os.makedirs(Saving_path_QC_folder)\r\n","\r\n","\r\n","imageA_folder = Saving_path_QC_folder+\"/A\"\r\n","os.makedirs(imageA_folder)\r\n","\r\n","imageB_folder = Saving_path_QC_folder+\"/B\"\r\n","os.makedirs(imageB_folder)\r\n","\r\n","imageAB_folder = Saving_path_QC_folder+\"/AB\"\r\n","os.makedirs(imageAB_folder)\r\n","\r\n","testAB_folder = Saving_path_QC_folder+\"/AB/test\"\r\n","os.makedirs(testAB_folder)\r\n","\r\n","testA_Folder = Saving_path_QC_folder+\"/A/test\"\r\n","os.makedirs(testA_Folder)\r\n"," \r\n","testB_Folder = Saving_path_QC_folder+\"/B/test\"\r\n","os.makedirs(testB_Folder)\r\n","\r\n","QC_checkpoint_folders = \"/content/\"+QC_model_name\r\n","\r\n","if os.path.exists(QC_checkpoint_folders):\r\n"," shutil.rmtree(QC_checkpoint_folders)\r\n","os.makedirs(QC_checkpoint_folders)\r\n","\r\n","\r\n","for files in os.listdir(Source_QC_folder):\r\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\r\n","\r\n","for files in os.listdir(Target_QC_folder):\r\n"," shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\r\n"," \r\n","#Here we create a merged folder containing only imageA\r\n","os.chdir(\"/content\")\r\n","\r\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A 
\"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\r\n","\r\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \r\n","random_choice = random.choice(os.listdir(Source_QC_folder))\r\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n","#Find image XY dimension\r\n","Image_Y = x.shape[0]\r\n","Image_X = x.shape[1]\r\n","\r\n","Image_min_dim = int(min(Image_Y, Image_X))\r\n","\r\n","patch_size_QC = Image_min_dim\r\n","\r\n","if not patch_size_QC % 256 == 0:\r\n"," patch_size_QC = ((int(patch_size_QC / 256)) * 256)\r\n"," print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\r\n","\r\n","if patch_size_QC < 256:\r\n"," patch_size_QC = 256\r\n","\r\n","\r\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\r\n","\r\n","print(Nb_Checkpoint)\r\n","\r\n","\r\n","## Initiate list\r\n","\r\n","Checkpoint_list = []\r\n","Average_ssim_score_list = []\r\n","\r\n","\r\n","for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\r\n"," checkpoints = j*5\r\n","\r\n"," if checkpoints == Nb_Checkpoint*5:\r\n"," checkpoints = \"latest\"\r\n","\r\n","\r\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\r\n","\r\n"," Checkpoint_list.append(checkpoints)\r\n","\r\n","\r\n"," # Create a quality control/Prediction Folder\r\n","\r\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\r\n","\r\n"," if os.path.exists(QC_prediction_results):\r\n"," shutil.rmtree(QC_prediction_results)\r\n","\r\n"," os.makedirs(QC_prediction_results)\r\n","\r\n","\r\n"," # Create a quality control/Prediction Folder\r\n","\r\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\r\n","\r\n"," if os.path.exists(QC_prediction_results):\r\n"," shutil.rmtree(QC_prediction_results)\r\n","\r\n"," os.makedirs(QC_prediction_results)\r\n","\r\n","\r\n","#---------------------------- Predictions are performed here ----------------------\r\n"," os.chdir(\"/content\")\r\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\r\n","#-----------------------------------------------------------------------------------\r\n","\r\n","#Here we need to move the data again and remove all the unnecessary folders\r\n","\r\n"," Checkpoint_name = \"test_\"+str(checkpoints)\r\n","\r\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\r\n","\r\n"," QC_results_images_files = os.listdir(QC_results_images)\r\n","\r\n"," for f in QC_results_images_files: \r\n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\r\n","\r\n"," os.chdir(\"/content\") \r\n","\r\n"," #Here we clean up the extra files\r\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\r\n","\r\n","\r\n"," #-------------------------------- QC for RGB ------------------------------------\r\n"," if Image_type == \"RGB\":\r\n","# List images in Source_QC_folder\r\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \r\n"," random_choice = 
random.choice(os.listdir(Source_QC_folder))\r\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\r\n","\r\n"," def ssim(img1, img2):\r\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\r\n","\r\n","# Open and create the csv file that will contain all the QC metrics\r\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\r\n"," writer = csv.writer(file)\r\n","\r\n"," # Write the header in the csv file\r\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\"])\r\n"," \r\n"," \r\n"," # Initiate list\r\n"," ssim_score_list = [] \r\n","\r\n","\r\n"," # Let's loop through the provided dataset in the QC folders\r\n","\r\n","\r\n"," for i in os.listdir(Source_QC_folder):\r\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\r\n"," print('Running QC on: '+i)\r\n","\r\n"," shortname_no_PNG = i[:-4]\r\n"," \r\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\r\n"," \r\n"," test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\r\n","\r\n"," # -------------------------------- Source test data --------------------------------\r\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\r\n"," \r\n"," \r\n"," # -------------------------------- Prediction --------------------------------\r\n"," \r\n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\r\n"," \r\n"," #--------------------------- Here we normalise using histograms matching--------------------------------\r\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\r\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\r\n"," \r\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\r\n","\r\n"," # Calculate the SSIM maps\r\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\r\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\r\n","\r\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\r\n","\r\n"," #Save ssim_maps\r\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\r\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\r\n"," \r\n"," \r\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\r\n","\r\n"," #Here we calculate the ssim average for each image in each checkpoints\r\n","\r\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\r\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\r\n","\r\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\r\n","\r\n"," if Image_type == \"Grayscale\":\r\n"," def 
ssim(img1, img2):\r\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\r\n","\r\n","\r\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\r\n","\r\n","\r\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\r\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\r\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\r\n","\r\n","\r\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\r\n"," \r\n"," if dtype is not None:\r\n"," x = x.astype(dtype,copy=False)\r\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\r\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\r\n"," eps = dtype(eps)\r\n","\r\n"," try:\r\n"," import numexpr\r\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\r\n"," except ImportError:\r\n"," x = (x - mi) / ( ma - mi + eps )\r\n","\r\n"," if clip:\r\n"," x = np.clip(x,0,1)\r\n","\r\n"," return x\r\n","\r\n"," def norm_minmse(gt, x, normalize_gt=True):\r\n"," \r\n"," if normalize_gt:\r\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\r\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\r\n"," #x = x - np.mean(x)\r\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\r\n"," #gt = gt - np.mean(gt)\r\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\r\n"," return gt, scale * x\r\n","\r\n","# Open and create the csv file that will contain all the QC metrics\r\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\r\n"," writer = csv.writer(file)\r\n","\r\n"," # Write the header in the csv file\r\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \r\n","\r\n"," \r\n"," \r\n"," # Let's loop through the provided dataset in the QC folders\r\n","\r\n","\r\n"," for i in os.listdir(Source_QC_folder):\r\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\r\n"," print('Running QC on: '+i)\r\n","\r\n"," ssim_score_list = []\r\n"," shortname_no_PNG = i[:-4]\r\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\r\n"," test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\r\n"," \r\n"," test_GT = test_GT_raw[:,:,2]\r\n","\r\n"," # -------------------------------- Source test data --------------------------------\r\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\r\n"," \r\n"," test_source = test_source_raw[:,:,2]\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\r\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\r\n","\r\n"," # -------------------------------- Prediction --------------------------------\r\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\r\n"," \r\n"," test_prediction = test_prediction_raw[:,:,2]\r\n","\r\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\r\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \r\n","\r\n","\r\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\r\n","\r\n"," # Calculate the SSIM maps\r\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\r\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\r\n","\r\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\r\n","\r\n"," #Save ssim_maps\r\n"," \r\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\r\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\r\n"," \r\n"," # Calculate the Root Squared Error (RSE) maps\r\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\r\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\r\n","\r\n"," # Save SE maps\r\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\r\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\r\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\r\n","\r\n","\r\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\r\n","\r\n"," # Normalised Root Mean Squared Error (here it's 
valid to take the mean of the image)\r\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\r\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\r\n"," \r\n"," # We can also measure the peak signal to noise ratio between the images\r\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\r\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\r\n","\r\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\r\n","\r\n"," #Here we calculate the ssim average for each image in each checkpoints\r\n","\r\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\r\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\r\n","\r\n","\r\n","# All data is now processed saved\r\n"," \r\n","\r\n","# -------------------------------- Display --------------------------------\r\n","\r\n","# Display the IoV vs Threshold plot\r\n","plt.figure(figsize=(20,5))\r\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\r\n","plt.title('Checkpoints vs. SSIM')\r\n","plt.ylabel('SSIM')\r\n","plt.xlabel('Checkpoints')\r\n","plt.legend()\r\n","plt.savefig(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\r\n","plt.show()\r\n","\r\n","\r\n","\r\n","# -------------------------------- Display RGB --------------------------------\r\n","\r\n","from ipywidgets import interact\r\n","import ipywidgets as widgets\r\n","\r\n","\r\n","if Image_type == \"RGB\":\r\n"," random_choice_shortname_no_PNG = shortname_no_PNG\r\n","\r\n"," @interact\r\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\r\n","\r\n"," random_choice_shortname_no_PNG = file[:-4]\r\n","\r\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\r\n"," df2 = df1.set_index(\"image #\", drop = False)\r\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\r\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\r\n","\r\n","#Setting up colours\r\n"," cmap = None\r\n","\r\n","\r\n"," plt.figure(figsize=(15,15))\r\n","\r\n","# Target (Ground-truth)\r\n"," plt.subplot(3,3,1)\r\n"," plt.axis('off')\r\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\r\n"," \r\n"," plt.imshow(img_GT, cmap = cmap)\r\n"," plt.title('Target',fontsize=15)\r\n","\r\n","# Source\r\n"," plt.subplot(3,3,2)\r\n"," plt.axis('off')\r\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\r\n"," plt.imshow(img_Source, cmap = cmap)\r\n"," plt.title('Source',fontsize=15)\r\n","\r\n","#Prediction\r\n"," plt.subplot(3,3,3)\r\n"," plt.axis('off')\r\n","\r\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\r\n","\r\n"," plt.imshow(img_Prediction, cmap = cmap)\r\n"," plt.title('Prediction',fontsize=15)\r\n","\r\n","\r\n","#SSIM between GT and Source\r\n"," plt.subplot(3,3,5)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\r\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. Source',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\r\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#SSIM between GT and Prediction\r\n"," plt.subplot(3,3,6)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n","\r\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\r\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. 
Prediction',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\r\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\r\n","\r\n","# -------------------------------- Display Grayscale --------------------------------\r\n","\r\n","if Image_type == \"Grayscale\":\r\n"," random_choice_shortname_no_PNG = shortname_no_PNG\r\n","\r\n"," @interact\r\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\r\n","\r\n"," random_choice_shortname_no_PNG = file[:-4]\r\n","\r\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\r\n"," df2 = df1.set_index(\"image #\", drop = False)\r\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\r\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\r\n","\r\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\r\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\r\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\r\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\r\n"," \r\n","\r\n"," plt.figure(figsize=(20,20))\r\n"," # Currently only displays the last computed set, from memory\r\n"," # Target (Ground-truth)\r\n"," plt.subplot(3,3,1)\r\n"," plt.axis('off')\r\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\r\n","\r\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\r\n"," plt.title('Target',fontsize=15)\r\n","\r\n","# Source\r\n"," plt.subplot(3,3,2)\r\n"," plt.axis('off')\r\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\r\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\r\n"," plt.title('Source',fontsize=15)\r\n","\r\n","#Prediction\r\n"," plt.subplot(3,3,3)\r\n"," plt.axis('off')\r\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\r\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\r\n"," plt.title('Prediction',fontsize=15)\r\n","\r\n","#Setting up colours\r\n"," cmap = plt.cm.CMRmap\r\n","\r\n","#SSIM between GT and Source\r\n"," plt.subplot(3,3,5)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\r\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\r\n","\r\n"," \r\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. 
Source',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\r\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#SSIM between GT and Prediction\r\n"," plt.subplot(3,3,6)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False) \r\n"," \r\n"," \r\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\r\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\r\n","\r\n"," \r\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\r\n"," plt.title('Target vs. Prediction',fontsize=15)\r\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\r\n","\r\n","#Root Squared Error between GT and Source\r\n"," plt.subplot(3,3,8)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\r\n"," \r\n","\r\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\r\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\r\n"," plt.title('Target vs. Source',fontsize=15)\r\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\r\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\r\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\r\n","\r\n","#Root Squared Error between GT and Prediction\r\n"," plt.subplot(3,3,9)\r\n","#plt.axis('off')\r\n"," plt.tick_params(\r\n"," axis='both', # changes apply to the x-axis and y-axis\r\n"," which='both', # both major and minor ticks are affected\r\n"," bottom=False, # ticks along the bottom edge are off\r\n"," top=False, # ticks along the top edge are off\r\n"," left=False, # ticks along the left edge are off\r\n"," right=False, # ticks along the right edge are off\r\n"," labelbottom=False,\r\n"," labelleft=False)\r\n","\r\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\r\n","\r\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\r\n","\r\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\r\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\r\n"," plt.title('Target vs. 
Prediction',fontsize=15)\r\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\r\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\r\n","\r\n","#Make a pdf summary of the QC results\r\n","\r\n","qc_pdf_export()\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\".\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","import glob\n","import os.path\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint chosen is now:\",checkpoint)\n","\n","\n"," \n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","\n","imageA_folder = Saving_path_prediction+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_prediction+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_prediction+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_Folder = Saving_path_prediction+\"/AB/test\"\n","os.makedirs(testAB_Folder)\n","\n","testA_Folder = Saving_path_prediction+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_prediction+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n"," shutil.copyfile(Data_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","# Here we create a merged A / A image for the prediction\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","# Here we count how many images are in our folder to be predicted and we add a few\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","# This will find the image dimension of a randomly chosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","Checkpoint_name = \"test_\"+str(checkpoint)\n","\n","\n","Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n","Prediction_results_images = os.listdir(Prediction_results_folder)\n","\n","for f in Prediction_results_images: \n"," if (f.endswith(\"_real_B.png\")): \n"," os.remove(Prediction_results_folder+\"/\"+f)\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Pdnb77E15zLE"},"source":["## **6.2. Inspect the predicted output**\r\n","---\r\n","\r\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"CrEBdt9T53Eh"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\r\n","import os\r\n","# This will display a randomly chosen dataset input and predicted output\r\n","random_choice = random.choice(os.listdir(Data_folder))\r\n","\r\n","\r\n","random_choice_no_extension = os.path.splitext(random_choice)\r\n","\r\n","\r\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\r\n","\r\n","\r\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\r\n","\r\n","f=plt.figure(figsize=(16,8))\r\n","plt.subplot(1,2,1)\r\n","plt.imshow(x, interpolation='nearest')\r\n","plt.title('Input')\r\n","plt.axis('off');\r\n","\r\n","plt.subplot(1,2,2)\r\n","plt.imshow(y, interpolation='nearest')\r\n","plt.title('Prediction')\r\n","plt.axis('off');\r\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\r\n","#**Thank you for using pix2pix!**"]}]} \ No newline at end of file