diff --git a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb index a0768985..19609eef 100644 --- a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb @@ -152,8 +152,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "3u2mXn3XsWzd", - "cellView": "form" + "cellView": "form", + "id": "3u2mXn3XsWzd" }, "outputs": [], "source": [ @@ -203,8 +203,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "aGxvAcGT-rTq", - "cellView": "form" + "cellView": "form", + "id": "aGxvAcGT-rTq" }, "outputs": [], "source": [ @@ -493,7 +493,7 @@ " pdf.set_font('Arial', size = 10, style = 'B')\n", " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", + " pdf.multi_cell(170, 5, txt = full_model_path, align = 'L')\n", " pdf.ln(1)\n", " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", " pdf.ln(1)\n", @@ -516,7 +516,7 @@ " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", " pdf.ln(1)\n", "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", + " pdf.output(full_model_path+'/'+model_name+\"_training_report.pdf\")\n", "\n", "\n", "#Make a pdf summary of the QC results\n", @@ -835,6 +835,7 @@ "#@markdown ###Name of the model and path to model folder:\n", "model_name = \"\" #@param {type:\"string\"}\n", "model_path = \"\" #@param {type:\"string\"}\n", + "full_model_path = model_path+'/'+model_name\n", "\n", "# other parameters for training.\n", "#@markdown ###Training Parameters\n", @@ -867,7 +868,7 @@ "\n", "\n", "#here we check that no model with the same name already exist, if so print a warning\n", - "if os.path.exists(model_path+'/'+model_name):\n", + "if os.path.exists(full_model_path):\n", " print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", " \n", @@ -911,9 +912,7 @@ " patch_size = ((int(patch_size / 8)-1) * 8)\n", " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", "\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", + "y = imread(os.path.join(Training_target, random_choice))\n", "\n", "f=plt.figure(figsize=(16,8))\n", "plt.subplot(1,2,1)\n", @@ -1129,8 +1128,7 @@ " \n", "\n", " for filename in os.listdir(Training_source_augmented):\n", - " os.chdir(Training_source_augmented)\n", - " os.rename(filename, filename.replace('_original', ''))\n", + " os.rename(os.path.join(Training_source_augmented,filename), os.path.join(Training_source_augmented,filename).replace('_original', ''))\n", " \n", " #Here we clean up the extra files\n", " shutil.rmtree(Augmented_folder)\n", @@ -1287,10 +1285,10 @@ "\n", "# --------------------- Here we delete the model folder if it already exist ------------------------\n", "\n", - "if os.path.exists(model_path+'/'+model_name):\n", + "if os.path.exists(full_model_path):\n", " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\"+W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", + " shutil.rmtree(full_model_path)\n", + "os.makedirs(full_model_path, exist_ok=True)\n", "\n", "\n", "# --------------------- Here we load the augmented data or the raw data ------------------------\n", @@ -1415,18 +1413,18 @@ "print(\"Training, done.\")\n", "\n", "# copy the .npz to the model's folder\n", - "shutil.copyfile(model_path+'/rawdata.npz',model_path+'/'+model_name+'/rawdata.npz')\n", + "shutil.copyfile(model_path+'/rawdata.npz',full_model_path+'/rawdata.npz')\n", "\n", "# convert the history.history dict to a pandas DataFrame: \n", "lossData = pd.DataFrame(history.history) \n", "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", + "if os.path.exists(full_model_path+\"/Quality Control\"):\n", + " shutil.rmtree(full_model_path+\"/Quality Control\")\n", "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", + "os.makedirs(full_model_path+\"/Quality Control\")\n", "\n", "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", + "lossDataCSVpath = full_model_path+'/Quality Control/training_evaluation.csv'\n", "with open(lossDataCSVpath, 'w') as f:\n", " writer = csv.writer(f)\n", " writer.writerow(['loss','val_loss', 'learning rate'])\n", @@ -1636,8 +1634,7 @@ "for filename in os.listdir(Source_QC_folder):\n", " img = imread(os.path.join(Source_QC_folder, filename))\n", " predicted = model_training.predict(img, axes='YX')\n", - " os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - " imsave(filename, predicted)\n", + " imsave(os.path.join(QC_model_path, QC_model_name, \"Quality Control\", \"Prediction\", filename), predicted)\n", "\n", "\n", "def ssim(img1, img2):\n", @@ -2025,7 +2022,6 @@ "random_choice = random.choice(os.listdir(Data_folder))\n", "x = imread(Data_folder+\"/\"+random_choice)\n", "\n", - "os.chdir(Result_folder)\n", "y = imread(Result_folder+\"/\"+random_choice)\n", "\n", "plt.figure(figsize=(16,8))\n", @@ -2095,8 +2091,8 @@ "accelerator": "GPU", "colab": { "machine_shape": "hm", - "toc_visible": true, - "provenance": [] + "provenance": [], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3", @@ -2117,4 +2113,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb index 75b08f63..f56071ca 100644 --- a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb @@ -165,7 +165,7 @@ "source": [ "#@markdown ##Install CARE and dependencies\n", "!pip uninstall -y -q tensorflow\n", - "!pip install -q tensorflow==2.5\n", + "!pip install -q tensorflow==2.8\n", "\n", "import tensorflow \n", "\n", @@ -259,7 +259,7 @@ " !pip freeze > $path\n", "\n", " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter)\n", + " df = pd.read_csv(path)\n", " mod_list = [m.split('.')[0] for m in after if not m in before]\n", " req_list_temp = df.values.tolist()\n", " req_list = [x[0] for x in req_list_temp]\n", @@ -958,8 +958,6 @@ "\n", "\n", "#Load one randomly chosen training target file\n", - "\n", - "os.chdir(Training_target)\n", "y = imread(Training_target+\"/\"+random_choice)\n", "\n", 
"f=plt.figure(figsize=(16,8))\n", @@ -1642,8 +1640,7 @@ " img = imread(os.path.join(Source_QC_folder, filename))\n", " n_slices = img.shape[0]\n", " predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n", - " os.chdir(path_metrics_save+'Prediction/')\n", - " imsave('Predicted_'+filename, predicted)\n", + " imsave(path_metrics_save+'Prediction/Predicted_'+filename, predicted)\n", "\n", "\n", "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", diff --git a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb index 65b2b206..9608afed 100644 --- a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jqvkQQkcuMmM"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"vCihhAzluRvI"},"outputs":[],"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. 
Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fq21zJVFNASx"},"outputs":[],"source":["Notebook_version = '1.13.1'\n","Network = 'CycleGAN'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path)\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item)\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install CycleGAN and dependencies\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf2\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","import torch\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is 
the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n"," \n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," # print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
initial_learning_rate{3}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n"," pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'cycleGAN'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}
{0}{1}{2}
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"zCvebubeSaGY"},"outputs":[],"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"outputs":[],"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. 
Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"outputs":[],"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","assert number_of_epochs > 5, \"Number of epochs should be greater than 5 in order to save model checkpoints.\"\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 2#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use Cyclegan we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the 
patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. \n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"DMqWq5-AxnFU"},"outputs":[],"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"9vC2n-HeLdiJ"},"outputs":[],"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. to prepare the training data into a suitable format for training."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"lIUAOJ_LMv5E"},"outputs":[],"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Data ready for training\")\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"iwNmp1PUzRDQ","scrolled":true},"outputs":[],"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Save training summary as pdf\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1Wext8woxt_F"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"eAJzMwPA6tlH"},"outputs":[],"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1CFbjvTpx5C3"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"q8tCfAadx96X"},"source":[" CycleGAN save model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"q2T4t8NNyDZ6"},"outputs":[],"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," os.chdir(\"/content\")\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name 
\"$QC_model_name\" --model test --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import 
numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Threshold plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"y2TD5p7MZrEb"},"outputs":[],"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that checkpoint exist, if not the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
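# --- Illustrative aside (not from this notebook): the checkpoint snapping performed
# --- above, written as a small helper. It rounds a requested epoch down to the nearest
# --- saved multiple of five (the cell above steps down one extra increment for values
# --- that are not multiples of five), and it avoids the warning message's reference to
# --- `checkpoints`, which is presumably meant to read `checkpoint`.
# def snap_checkpoint(checkpoint, nb_checkpoint):
#     if checkpoint == "latest":
#         return "latest"
#     checkpoint = max(5, (int(checkpoint) // 5) * 5)   # nearest lower multiple of 5, minimum 5
#     if checkpoint >= nb_checkpoint * 5:               # past the last saved epoch: fall back to "latest"
#         return "latest"
#     return checkpoint
# e.g. snap_checkpoint(37, Nb_Checkpoint) would give 35.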
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"SXqS_EhByhQ7"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"64emoATwylxM"},"outputs":[],"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"pE8vQZ7RWY_L"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","This version also now includes built-in version check and the version log that \n","\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using CycleGAN!**"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":[],"machine_shape":"hm","name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611059046709},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"}},"nbformat":4,"nbformat_minor":0} +{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
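As an aside, here is a conceptual sketch of the objective the two GANs described above optimise (illustrative only; the notebook delegates all training to the pytorch-CycleGAN-and-pix2pix scripts, and `G`, `F`, `D_X`, `D_Y` are assumed PyTorch modules for the two generators and two discriminators):

```python
import torch

def cyclegan_generator_loss(G, F, D_X, D_Y, real_x, real_y, lambda_cyc=10.0):
    # Translate each domain into the other
    fake_y = G(real_x)          # X -> Y
    fake_x = F(real_y)          # Y -> X
    # Least-squares adversarial terms: each generator tries to fool its discriminator
    adv = torch.mean((D_Y(fake_y) - 1) ** 2) + torch.mean((D_X(fake_x) - 1) ** 2)
    # Cycle-consistency: translating there and back should recover the original image
    cyc = torch.mean(torch.abs(F(fake_y) - real_x)) + torch.mean(torch.abs(G(fake_x) - real_y))
    # lambda_cyc weights the cycle term; a value around 10 is a common default
    return adv + lambda_cyc * cyc
```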
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jqvkQQkcuMmM"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"vCihhAzluRvI"},"outputs":[],"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. 
Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fq21zJVFNASx"},"outputs":[],"source":["Notebook_version = '1.13.1'\n","Network = 'CycleGAN'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path)\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item)\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install CycleGAN and dependencies\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","!pip install -r pytorch-CycleGAN-and-pix2pix/requirements.txt\n","!pip install fpdf2\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","import torch\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version 
of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n"," \n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," # print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th align=\"left\">Parameter</th><th align=\"left\">Value</th></tr>\n","    <tr><td>number_of_epochs</td><td>{0}</td></tr>\n","    <tr><td>patch_size</td><td>{1}</td></tr>\n","    <tr><td>batch_size</td><td>{2}</td></tr>\n","    <tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n"," pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'cycleGAN'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"zCvebubeSaGY"},"outputs":[],"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"outputs":[],"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. 
Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"outputs":[],"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","assert number_of_epochs > 5, \"Number of epochs should be greater than 5 in order to save model checkpoints.\"\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 2#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use Cyclegan we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the 
patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. \n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"DMqWq5-AxnFU"},"outputs":[],"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
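The next cell enables the pre-trained option only when both generator checkpoints (`latest_net_G_A.pth` and `latest_net_G_B.pth`, the file names this notebook expects) are found in the folder you provide. A minimal sketch of that check, with an illustrative helper name, is shown below; grouping the two `os.path.exists` calls keeps the logic unambiguous, since `not a and b` would only negate the first test.

```python
import os

def pretrained_cyclegan_weights_found(model_folder):
    """Return True only if both CycleGAN generator checkpoints are present.

    If either 'latest_net_G_A.pth' or 'latest_net_G_B.pth' is missing,
    training should fall back to randomly initialised weights.
    """
    generator_A = os.path.join(model_folder, "latest_net_G_A.pth")
    generator_B = os.path.join(model_folder, "latest_net_G_B.pth")
    return os.path.exists(generator_A) and os.path.exists(generator_B)

# Illustrative usage (the path below is a placeholder, not a real folder):
# Use_pretrained_model = pretrained_cyclegan_weights_found("/content/gdrive/MyDrive/my_cyclegan_model")
```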
"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"9vC2n-HeLdiJ"},"outputs":[],"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. to prepare the training data into a suitable format for training."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"lIUAOJ_LMv5E"},"outputs":[],"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Data ready for training\")\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"iwNmp1PUzRDQ","scrolled":true},"outputs":[],"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Save training summary as pdf\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1Wext8woxt_F"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"eAJzMwPA6tlH"},"outputs":[],"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"1CFbjvTpx5C3"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"q8tCfAadx96X"},"source":[" CycleGAN save model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"q2T4t8NNyDZ6"},"outputs":[],"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch 
$checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 
255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can 
also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Threshold plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"y2TD5p7MZrEb"},"outputs":[],"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that checkpoint exist, if not the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly choosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
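# --- Illustrative sketch (not part of the notebook): the checkpoint-selection rule used above,
# --- collected into one hypothetical helper. Checkpoints are saved every 5 epochs, so a requested
# --- value is snapped down to the previous multiple of 5, and anything below 10 or at/beyond the
# --- last saved checkpoint falls back to 5 or "latest" respectively.
def snap_checkpoint(requested, n_checkpoints, step=5):
    if requested == "latest":
        return "latest"
    if requested < 2 * step:                 # too early: use the first saved checkpoint
        return step
    if requested % step != 0:                # not a multiple of `step`: snap down
        requested = (requested // step - 1) * step
    if requested >= n_checkpoints * step:    # at or past the last saved checkpoint
        return "latest"
    return requested
# Example: snap_checkpoint(37, n_checkpoints=10) -> 30 ; snap_checkpoint(60, n_checkpoints=10) -> "latest"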
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"SXqS_EhByhQ7"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"64emoATwylxM"},"outputs":[],"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"pE8vQZ7RWY_L"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","This version also now includes built-in version check and the version log that \n","\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using CycleGAN!**"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":[],"machine_shape":"hm","name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611059046709},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"}},"nbformat":4,"nbformat_minor":0} diff --git a/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb index 58d0890b..3c99fc35 100644 --- a/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb @@ -7,7 +7,7 @@ "id": "IkSguVy8Xv83" }, "source": [ - "#**Label-free prediction - fnet**\n", + "# **Label-free prediction - fnet**\n", "---\n", "\n", " \n", @@ -43,7 +43,7 @@ "\n", "\n", "---\n", - "###**Structure of a notebook**\n", + "### **Structure of a notebook**\n", "\n", "The notebook contains two types of cell: \n", "\n", @@ -52,7 +52,7 @@ "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", "\n", "---\n", - "###**Table of contents, Code snippets** and **Files**\n", + "### **Table of contents, Code snippets** and **Files**\n", "\n", "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", "\n", @@ -67,7 +67,7 @@ "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", "\n", "---\n", - "###**Making changes to the notebook**\n", + "### **Making changes to the notebook**\n", "\n", "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", "\n", @@ -82,7 +82,7 @@ "id": "gKDLkLWUd-YX" }, "source": [ - "#**0. Before getting started**\n", + "# **0. Before getting started**\n", "---\n", "**Data Format**\n", "\n", @@ -125,7 +125,7 @@ "id": "AdN8B91xZO0x" }, "source": [ - "#**1. Install fnet and dependencies**\n", + "# **1. Install fnet and dependencies**\n", "---\n", "Running fnet requires the fnet folder to be downloaded into the session's Files. 
As fnet needs several packages to be installed, this step may take a few minutes.\n", "\n", @@ -156,12 +156,11 @@ "source": [ "#@markdown ##Install fnet and dependencies\n", "\n", - "!pip install fpdf2\n", - "!pip install -U scipy==1.2.0\n", - "!pip install scikit-image==0.16.2\n", - "!pip install tifffile==2019.7.26\n", - "!pip install matplotlib==2.2.3\n", - "!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113\n", + "!pip install -q fpdf2\n", + "!pip install -q scikit-image\n", + "!pip install -q tifffile\n", + "!pip install -q matplotlib\n", + "!pip install -q torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113\n", "\n", "#Force session restart\n", "exit(0)" @@ -251,7 +250,6 @@ "before = [str(m) for m in sys.modules]\n", "\n", "# !pip install fpdf\n", - "# !pip install -U scipy==1.2.0\n", "# !pip install tifffile==2019.7.26\n", "# !pip install matplotlib==2.2.3\n", "#@markdown ##Load key dependencies\n", @@ -283,7 +281,7 @@ "print(\"Tensorflow enabled.\")\n", "\n", "#clone fnet from github to colab\n", - "!git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n", + "!git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install -q .\n", "\n", "from skimage import io\n", "from matplotlib import pyplot as plt\n", @@ -414,10 +412,7 @@ "#Change the default dataset type in the training module to .tif\n", "replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n", "\n", - "\n", - "\n", "#2D \n", - "\n", "replace(\"/content/pytorch_fnet/train_model.py\",\"default=[32, 64, 64]\",\"default=[128, 128]\")\n", "replace(\"/content/pytorch_fnet/train_model.py\",\"'--nn_module', default='fnet_nn_3d'\",\"'--nn_module', default='fnet_nn_2d'\")\n", "\n", @@ -478,7 +473,7 @@ " #Main Packages\n", " main_packages = ''\n", " version_numbers = []\n", - " for name in ['tensorflow','numpy','torch','scipy']:\n", + " for name in ['tensorflow','numpy','torch']:\n", " find_name=all_packages.find(name)\n", " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", " #Version numbers only here:\n", @@ -494,9 +489,9 @@ " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", " dataset_size = len(os.listdir(Training_source))\n", "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", + " text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", "\n", - " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n", + " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n", "\n", " #if Use_pretrained_model:\n", " # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n", @@ -960,8 +955,7 @@ "#Here, we edit this file to include the desired parameters\n", "\n", "#1. Add permissions to train_model.sh\n", - "os.chdir(\"/content/pytorch_fnet/scripts\")\n", - "!chmod u+x train_model.sh\n", + "!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n", "\n", "#2. Select parameters\n", "steps = 1000#@param {type:\"number\"}\n", @@ -973,18 +967,18 @@ "\n", "#3. Insert the above values into train_model.sh\n", "!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n", - "!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n", - "!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n", - "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n", - "!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n", - "!sed -i \"s/LR=.*/LR=$learning_rate/g\" train_model.sh #change the learning rate\n", - "!sed -i \"s/PATCH_SIZE=.*/PATCH_SIZE=$patch_size/g\" train_model.sh #change the patch size\n", + "!sed -i \"s/1:-.*/1:-$model_name_x/g\" /content/pytorch_fnet/scripts/train_model.sh #change the dataset to be trained with\n", + "!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training iterations (steps)\n", + "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n", + "!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" /content/pytorch_fnet/scripts/train_model.sh #change the batch size\n", + "!sed -i \"s/LR=.*/LR=$learning_rate/g\" /content/pytorch_fnet/scripts/train_model.sh #change the learning rate\n", + "!sed -i \"s/PATCH_SIZE=.*/PATCH_SIZE=$patch_size/g\" /content/pytorch_fnet/scripts/train_model.sh #change the patch size\n", "\n", "\n", "\n", - "!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n", - "!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n", - "!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n", + "!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n", "\n", "#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n", "\n", @@ -999,7 +993,6 @@ "random_choice=random.choice(os.listdir(Training_source))\n", "x = io.imread(Training_source+\"/\"+random_choice)\n", "\n", - "os.chdir(Training_target)\n", "y = io.imread(Training_target+\"/\"+random_choice)\n", "\n", "f=plt.figure(figsize=(16,8))\n", @@ -1202,9 +1195,8 @@ " # else:\n", " # number_of_images = len(aug_source)\n", "\n", - " os.chdir(\"/content/pytorch_fnet/scripts\")\n", - " !chmod u+x train_model.sh\n", - " !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n", + " !chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n", + " !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n", "\n", " print(\"Done\")\n", "if not Use_Data_augmentation:\n", @@ -1262,11 +1254,7 @@ "if os.path.exists(model_path+'/'+model_name):\n", " shutil.rmtree(model_path+'/'+model_name)\n", " print(bcolors.WARNING +\"!! 
Existing model \"+model_name+\" was found and overwritten!!\")\n", - "os.mkdir(model_path+'/'+model_name)\n", - "\n", - "#os.chdir(model_path)\n", - "# source = os.listdir(Training_source)\n", - "# target = os.listdir(Training_target)\n", + "os.makedirs(model_path+'/'+model_name, exist_ok=True)\n", "\n", "if Use_Data_augmentation == True:\n", "\n", @@ -1317,10 +1305,9 @@ "outputs": [], "source": [ "#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n", - "os.chdir(\"/content/pytorch_fnet/scripts\")\n", "number_of_images = 50#@param{type:\"number\"}\n", - "!chmod u+x train_model.sh\n", - "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images" + "!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images" ] }, { @@ -1337,7 +1324,6 @@ "pdf_export(augmentation = Use_Data_augmentation)\n", "start = time.time()\n", "\n", - "os.chdir('/content')\n", "add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n", "\n", "### TRAIN THE MODEL ###\n", @@ -1461,25 +1447,22 @@ "batch_size = 64 #@param {type:\"number\"}\n", "\n", "#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n", - "os.chdir(\"/content/pytorch_fnet/scripts\")\n", - "number_of_images = 100#@param{type:\"number\"}\n", - "!chmod u+x train_model.sh\n", - "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n", + "number_of_images = 100 #@param{type:\"number\"}\n", + "!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n", "\n", "# Editing the train.sh script file #\n", "\n", - "os.chdir('/content/pytorch_fnet/scripts')\n", - "\n", "#Change the train_model.sh file to include chosen dataset\n", "!chmod u+x ./train_model.sh\n", - "!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n", - "!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n", - "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n", - "!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n", + "!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" /content/pytorch_fnet/scripts/train_model.sh #Use the whole training dataset for training\n", + "!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n", + "!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" /content/pytorch_fnet/scripts/train_model.sh #change the batch size\n", "\n", "!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n", - "!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n", - "!sed 
-i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n", + "!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n", + "!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n", "\n", "replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n", "replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n", @@ -1497,10 +1480,9 @@ "\n", "#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n", "new_steps = previous_steps + add_steps -1\n", - "os.chdir('/content/pytorch_fnet/scripts')\n", "\n", "#Edit train_model.sh file to include new total number of training epochs\n", - "!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n", + "!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" /content/pytorch_fnet/scripts/train_model.sh\n", "\n", "# Display example data #\n", "\n", @@ -1512,7 +1494,6 @@ "# Image_Z = x.shape[0]\n", "# mid_plane = int(Image_Z / 2)+1\n", "\n", - "#os.chdir(Training_target)\n", "y = io.imread(ExampleTarget)\n", "\n", "f=plt.figure(figsize=(16,8))\n", @@ -1542,14 +1523,11 @@ "#@markdown ##4.2. Start re-training model\n", "# !pip install tifffile==2019.7.26\n", "\n", - "os.chdir('/content/pytorch_fnet/fnet')\n", - "\n", "add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n", "\n", "#Here we retrain the model on the chosen dataset.\n", - "os.chdir('/content/pytorch_fnet/')\n", - "!chmod u+x ./scripts/train_model.sh\n", - "!./scripts/train_model.sh $Pretrained_model_name 0\n", + "!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n", + "!/content/pytorch_fnet/scripts/train_model.sh $Pretrained_model_name 0\n", "\n", "# Displaying the time elapsed for training\n", "dt = time.time() - start\n", @@ -1772,13 +1750,12 @@ "### Editing the predict.sh script file ###\n", "\n", "# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n", - "os.chdir('/content/pytorch_fnet/')\n", "# !chmod u+x ./scripts/predict.sh\n", "# !sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict.sh\n", - "!chmod u+x ./scripts/predict_2d.sh\n", - "!sed -i \"1,21!d\" ./scripts/predict_2d.sh\n", + "!chmod u+x /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"1,21!d\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", - "!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict_2d.sh\n", + "!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n", "#!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n", @@ -1786,10 +1763,10 @@ "# !if ! 
grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n", "# !if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/predict.sh; fi \n", "\n", - "!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/N_IMAGES=.*/N_IMAGES=$qc_images/g\" ./scripts/predict_2d.sh\n", + "!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/N_IMAGES=.*/N_IMAGES=$qc_images/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "### Create a path csv file for prediction (QC)###\n", "\n", @@ -2131,8 +2108,7 @@ "new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n", "\n", "# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n", - "os.chdir('/content/pytorch_fnet/')\n", - "!chmod u+x ./scripts/predict_2d.sh\n", + "!chmod u+x /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "### Editing the predict.sh script file ###\n", "\n", @@ -2141,17 +2117,17 @@ "# !if grep CziDataset ./scripts/predict_2d.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n", "\n", "# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n", - "!sed -i \"s/N_IMAGES=.*/N_IMAGES=$data_files/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" ./scripts/predict_2d.sh\n", + "!sed -i \"s/N_IMAGES=.*/N_IMAGES=$data_files/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n", - "!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n", - "!sed -i \"1,21!d\" ./scripts/predict_2d.sh\n", + "!sed -i \"s/in test.*/in test/g\" /content/pytorch_fnet/scripts/predict.sh\n", + "!sed -i \"1,21!d\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "#We change the directories in the predict.sh file to our needed paths\n", - "!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" ./scripts/predict_2d.sh\n", - "!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" ./scripts/predict_2d.sh\n", + "!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", + "!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict_2d.sh\n", "\n", "# Changing the GPU ID seems to help reduce errors\n", "# replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n", @@ -2217,14 +2193,12 @@ "#@markdown ###Select the image would you like to view below\n", "\n", "def show_image(file=os.listdir(Data_folder)):\n", - " os.chdir(Results_folder)\n", - "\n", - "#source_image = io.imread(test_signal[0])\n", - " source_image = io.imread(os.path.join(Data_folder,file))\n", - " prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n", + " #source_image = io.imread(test_signal[0])\n", + " source_image = io.imread(os.path.join(Results_folder,Data_folder,file))\n", + " prediction_image = io.imread(os.path.join(Results_folder,Results_folder,'Predictions/predicted_'+file))\n", " prediction_image = np.squeeze(prediction_image, axis=(0,))\n", "\n", - "#Create the figure\n", + " #Create the figure\n", " fig = plt.figure(figsize=(10,20))\n", "\n", " #Setting up colours\n", diff --git a/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb index aa79fd7c..adc3da9e 100644 --- a/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). 
When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**1. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"GgmEMSOUybyu"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"bGu_k66ZxoJW"},"outputs":[],"source":["#@markdown ##Install fnet and dependencies\n","!pip install fpdf\n","#clone fnet from github to colab\n","!git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","!pip install tifffile==2019.7.26\n","# !pip install --no-cache-dir tifffile==2019.7.26 \n","#Force session restart\n","exit(0)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_j2XyI76yhtT"},"source":["## **1.2. 
Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"hKXc0D11y6q8"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fq21zJVFNASx"},"outputs":[],"source":["Notebook_version = '1.13.1'\n","Network = 'fnet (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path)\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item)\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key dependencies\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import numpy as np\n","import shutil\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","from astropy.visualization import simple_norm\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","import matplotlib as mpl\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","def replace(file_path, pattern, subst):\n"," \"\"\"Function replaces a pattern in a .py file with a new pattern.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -file_path (string): path to the file to be changed.\n"," -pattern (string): pattern to be replaced. Make sure this is as unique as possible.\n"," -subst (string): new pattern. 
\"\"\"\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def add_insert(filepath,line_number,insertion,append):\n"," \"\"\"Function which inserts the a line into a document.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -filepath (string): path to the file which needs to be edited.\n"," -line (integer): Where to insert the new line. In the file, this line is ideally an empty one.\n"," -insertion (string): The line to be inserted. If it already exists it will not be added again.\n"," -append (string): If anything additional needs to be appended to the line, use this. Otherwise, leave as \"\" \"\"\"\n"," \n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," if append != \"\":\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","\n","def convert_to_script_compatible_path(original_path):\n"," \"\"\"Function which converts 'original_path' into a compatible format 'new_full_path' with the fnet .sh files \"\"\"\n"," new_full_path = \"\"\n"," for s in original_path:\n"," if s=='/':\n"," new_full_path += '\\/'\n"," else:\n"," new_full_path += s\n","\n"," return new_full_path\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / 
np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","#Change the default dataset type in the training module to .tif\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = 
io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>percentage_validation</td><td>{0}</td></tr>
<tr><td>steps</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," pdf.ln(1)\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/QualityControl/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = 
round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/QualityControl/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}
{0}{1}{2}{3}{4}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(full_QC_model_path+'/QualityControl/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"zCvebubeSaGY"},"outputs":[],"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"outputs":[],"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. 
Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"VM8YvXMLzXyA"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
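\n","\n","If Drive is mounted but your files still do not appear, forcing a remount usually refreshes it. A minimal sketch using the standard `google.colab` API (the mount point matches the one used in section 2.2):\n","\n","```python\n","from google.colab import drive\n","\n","# force_remount=True re-authenticates and refreshes an existing mount\n","drive.mount('/content/gdrive', force_remount=True)\n","```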
"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**\n","\n","**Note 2: If you only need to retrain your model after a time-out, skip this cell and go straight to section 4.2. Just make sure your training datasets are still in their original folders.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"outputs":[],"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","### Choosing and editing the path names ###\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","full_model_path = model_path+'/'+model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+model_name+'_val\\.csv'\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", skip this cell and instead load \"+model_name+\" as Pretrained_model_folder in section 4.2\")\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","\n","### Edit the train.sh script file and train.py file ###\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n"," \n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","\n","source = os.listdir(Training_source)\n","target = os.listdir(Training_target)\n","number_of_images = len(source[:-round(len(source)*(percentage_validation/100))])\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images).\n","\n","**Note 2: If you intend to use the retraining option at a later point, save the dataset in a folder in your Google Drive.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"DMqWq5-AxnFU"},"outputs":[],"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," 
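# Save the flipped target stacks under the same filenames as the flipped source stacks so source/target pairs stay matched\n","            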
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(os.listdir(Saving_path+'/augmented_source'))>100:\n"," number_of_images = 100\n"," else:\n"," number_of_images = len(os.listdir(Saving_path+'/augmented_source'))\n","\n"," os.chdir(\"/content/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","\n","Before training, carefully read the different options. 
This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.2. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"MQvrHFVcJ6VT"},"outputs":[],"source":["#@markdown ##Create the dataset files for training\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," print(bcolors.WARNING +\"!! Existing model \"+model_name+\" was found and overwritten!!\")\n","os.mkdir(model_path+'/'+model_name)\n","\n","#os.chdir(model_path)\n","# source = os.listdir(Training_source)\n","# target = os.listdir(Training_target)\n","\n","if Use_Data_augmentation == True:\n","\n"," aug_source = os.listdir(Saving_path+'/augmented_source')\n"," aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_val_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_source_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n","else:\n"," #Here we define the random set of training files to be used for validation\n"," val_files = source[-round(len(source)*(percentage_validation/100)):]\n"," source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_files)):\n"," writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in 
range(0,len(source_files)):\n"," writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"X8YHeSGr76je"},"outputs":[],"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"7Ofm-71T8ABX"},"outputs":[],"source":["#@markdown ##Start training\n","pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content')\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","### TRAIN THE MODEL ###\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!/content/pytorch_fnet/scripts/train_model.sh $model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on. Make sure your training datasets are in the same location as when you originally trained the model.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"iDIgosht8U7F"},"outputs":[],"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. 
This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we repeat steps already used above in case the notebook needs to be restarted for this cell.\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n","\n","\n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","### Choosing and editing the path names ###\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","\n","full_model_path = Pretrained_model_path+'/'+Pretrained_model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+Pretrained_model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+Pretrained_model_name+'_val\\.csv'\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#We get the example data and the number of images from the csv path file#\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_images = 0\n"," for line in csvreader:\n"," ExampleSource = line[0]\n"," ExampleTarget = line[1]\n"," number_of_images += 1\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'_val.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_val_images = 0\n"," for line in csvreader:\n"," number_of_val_images += 1\n","\n","#Batch Size\n","\n","batch_size = 4 #@param {type:\"number\"}\n","\n","# Editing the train.sh script file #\n","\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" 
train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","# Find the number of steps to add and then add #\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 150#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","# Display example data #\n","\n","#Load one randomly chosen training source file\n","#random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(ExampleSource)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#os.chdir(Training_target)\n","y = io.imread(ExampleTarget)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"h1INk9nRE15L"},"outputs":[],"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"5IXdFqhM8gO2"},"outputs":[],"source":["Use_Data_augmentation = False \n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","\n","os.chdir('/content/pytorch_fnet/fnet')\n","\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Here, we redefine the variable names for the pdf export\n","percentage_validation = round((number_of_val_images/(number_of_images+number_of_val_images))*100)\n","steps = new_steps\n","model_name = Pretrained_model_name\n","model_path = Pretrained_model_path\n","Training_source = os.path.dirname(ExampleSource)\n","Training_target = os.path.dirname(ExampleTarget)\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"eAJzMwPA6tlH"},"outputs":[],"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"outputs":[],"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/QualityControl/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n","\n","**Note 2:** If you get an 'sequence argument must have length equal to input rank' error, you may need to reshape your images from [z, x, y, c] or [c,z,x,y] to [z,x,y] by squeezing out the channel dimension, e.g. using numpy.squeeze. 
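\n","\n","A minimal sketch of that reshaping step (the file name is only a placeholder), using the same `numpy` and `skimage.io` calls the notebook already relies on:\n","\n","```python\n","import numpy as np\n","from skimage import io\n","\n","stack = io.imread('/content/example_stack.tif')   # e.g. shape (1, z, x, y) or (z, x, y, 1)\n","stack = np.squeeze(stack)                         # drop singleton axes, leaving (z, x, y)\n","io.imsave('/content/example_stack.tif', stack)\n","```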
"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"w90MdriMxhjD"},"outputs":[],"source":["#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","### Choosing and editing the path names ###\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\")\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","new_full_QC_model_path = convert_to_script_compatible_path(full_QC_model_path)\n","new_full_QC_model_path_dataset = new_full_QC_model_path+'\\${DATASET}'\n","new_full_QC_model_path_csv = new_full_QC_model_path+'\\/QualityControl\\/qc\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","\n","### Editing the predict.sh script file ###\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","!if ! 
grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/predict.sh; fi \n","\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","\n","### Create a path csv file for prediction (QC)###\n","\n","#Here we create a qctest.csv to locate the files used for QC\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","\n","with open(full_QC_model_path+'/QualityControl/qctest.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Source_QC_folder+'/'+test_signal[i],Target_QC_folder+'/'+test_signal[i]])\n","\n","### RUN THE PREDICTION ###\n","!/content/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","### Save the results ###\n","QC_results_files = os.listdir(full_QC_model_path+'/QualityControl/Predictions')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," if os.path.isdir(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/prediction_'+QC_model_name+'.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction/'+'Predicted_'+test_signal[i])\n"," if os.path.exists(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff'):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n"," else:\n"," shutil.copyfile(Source_QC_folder+'/'+test_signal[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(Target_QC_folder+'/'+test_target[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n","\n","shutil.rmtree(full_QC_model_path+'/QualityControl/Predictions')\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = 
io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/QualityControl/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," 
mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, 
vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/QualityControl/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"y2TD5p7MZrEb"},"outputs":[],"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(Results_folder+\"/Predictions\"):\n"," shutil.rmtree(Results_folder+\"/Predictions\")\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","### Choosing and editing the path names ###\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+model_name\n","\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","Prediction_model_name_x = Prediction_model_name+\"}\"\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Convert the path variables into a compatible format with the script files #\n","# Prediction path conversion\n","new_full_Prediction_model_path = convert_to_script_compatible_path(full_Prediction_model_path)\n","new_full_Prediction_model_path_csv = new_full_Prediction_model_path+'\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","# Result path conversion\n","new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","\n","### Editing the predict.sh script file ###\n","\n","# Make sure the dataset type is set to .tif (debug note: could be changed at install in predict.py file?)\n","!if ! grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","#We change the directories in the predict.sh file to our needed paths\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","# Changing the GPU ID seems to help reduce errors\n","replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n","\n","# We get rid of the options of saving signals and targets. Here, we just want predictions.\n","insert_1 = ' --no_signal \\\\\\n'\n","insert_2 = ' --no_target \\\\\\n'\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_1,\"\")\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_2,\"\")\n","\n","### Create the path csv file for prediction ###\n","\n","#Here we create a new test.csv with the paths to the dataset we want to predict on.\n","test_signal = os.listdir(Data_folder)\n","with open(full_Prediction_model_path+'/test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Data_folder+\"/\"+test_signal[i],Data_folder+\"/\"+test_signal[i]])\n","\n","### WE RUN THE PREDICTION ###\n","start = time.time()\n","!/content/pytorch_fnet/scripts/predict.sh $Prediction_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Rename the results appropriately\n","Results = os.listdir(Results_folder+'/Predictions')\n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.copyfile(Results_folder+'/Predictions/'+i+'/'+os.listdir(Results_folder+'/Predictions/'+i)[0],Results_folder+'/Predictions/'+'predicted_'+test_signal[int(i)])\n"," \n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.rmtree(Results_folder+'/Predictions/'+i)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. 
Select the slice you want to visualize."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"66-af3rO9vM4"},"outputs":[],"source":["#@markdown ###Select the slice you would like to view\n","slice_number = 15#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and, after that, clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"uRcJEjslvTj2"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version has an additional step before re-training in section 4.2. which allows you to change the number of images loaded into the buffer.\n","* An additional note is given for the QC step, indicating the shape of the image files.\n","* Existing model files are now overwritten in an additional section before the training cell, allowing errors to be corrected before the model folder is overwritten.\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","* This version also now includes a built-in version check and the version log that you're reading now."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["IkSguVy8Xv83","jWAz2i7RdxUV","gKDLkLWUd-YX","UvSlTaH14s3t"],"machine_shape":"hm","name":"fnet_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1SisekHpRSJ0QKHvDePqFe09lkklVytwI","timestamp":1622728423435},{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1620660071757},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}]},"kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"}},"nbformat":4,"nbformat_minor":0} +{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. 
After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**1. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. 
As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings**, as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"GgmEMSOUybyu"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"bGu_k66ZxoJW"},"outputs":[],"source":["#@markdown ##Install fnet and dependencies\n","!pip install fpdf\n","#clone fnet from github to colab\n","!git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","!pip install tifffile==2019.7.26\n","# !pip install --no-cache-dir tifffile==2019.7.26 \n","#Force session restart\n","exit(0)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_j2XyI76yhtT"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"hKXc0D11y6q8"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fq21zJVFNASx"},"outputs":[],"source":["Notebook_version = '1.13.1'\n","Network = 'fnet (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path)\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item)\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key dependencies\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import numpy as np\n","import shutil\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","from astropy.visualization import simple_norm\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","import matplotlib as mpl\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","def replace(file_path, pattern, subst):\n"," \"\"\"Function replaces a pattern in a .py file with a new pattern.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -file_path (string): path to the file to be changed.\n"," -pattern (string): pattern to be replaced. Make sure this is as unique as possible.\n"," -subst (string): new pattern. 
\"\"\"\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def add_insert(filepath,line_number,insertion,append):\n"," \"\"\"Function which inserts the a line into a document.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -filepath (string): path to the file which needs to be edited.\n"," -line (integer): Where to insert the new line. In the file, this line is ideally an empty one.\n"," -insertion (string): The line to be inserted. If it already exists it will not be added again.\n"," -append (string): If anything additional needs to be appended to the line, use this. Otherwise, leave as \"\" \"\"\"\n"," \n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," if append != \"\":\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","\n","def convert_to_script_compatible_path(original_path):\n"," \"\"\"Function which converts 'original_path' into a compatible format 'new_full_path' with the fnet .sh files \"\"\"\n"," new_full_path = \"\"\n"," for s in original_path:\n"," if s=='/':\n"," new_full_path += '\\/'\n"," else:\n"," new_full_path += s\n","\n"," return new_full_path\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / 
np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","#Change the default dataset type in the training module to .tif\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = 
io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
percentage_validation{0}
steps{1}
batch_size{2}
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," pdf.ln(1)\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," pdf.ln(1)\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/QualityControl/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = 
round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/QualityControl/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}
{0}{1}{2}{3}{4}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," pdf.ln(1)\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(1)\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n"," pdf.ln(1)\n","\n"," pdf.output(full_QC_model_path+'/QualityControl/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"zCvebubeSaGY"},"outputs":[],"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"outputs":[],"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. 
Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"VM8YvXMLzXyA"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**\n","\n","**Note 2: If you only need to retrain your model after a time-out, skip this cell and go straight to section 4.2. Just make sure your training datasets are still in their original folders.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"outputs":[],"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","### Choosing and editing the path names ###\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","full_model_path = model_path+'/'+model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+model_name+'_val\\.csv'\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", skip this cell and instead load \"+model_name+\" as Pretrained_model_folder in section 4.2\")\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","\n","### Edit the train.sh script file and train.py file ###\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n"," \n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","\n","source = os.listdir(Training_source)\n","target = os.listdir(Training_target)\n","number_of_images = len(source[:-round(len(source)*(percentage_validation/100))])\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images).\n","\n","**Note 2: If you intend to use the retraining option at a later point, save the dataset in a folder in your Google Drive.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"DMqWq5-AxnFU"},"outputs":[],"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," 
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(os.listdir(Saving_path+'/augmented_source'))>100:\n"," number_of_images = 100\n"," else:\n"," number_of_images = len(os.listdir(Saving_path+'/augmented_source'))\n","\n"," !chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","\n","Before training, carefully read the different options. 
This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.2. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"MQvrHFVcJ6VT"},"outputs":[],"source":["#@markdown ##Create the dataset files for training\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," print(bcolors.WARNING +\"!! Existing model \"+model_name+\" was found and overwritten!!\")\n","os.mkdir(model_path+'/'+model_name)\n","\n","if Use_Data_augmentation == True:\n","\n"," aug_source = os.listdir(Saving_path+'/augmented_source')\n"," aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_val_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_source_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n","else:\n"," #Here we define the random set of training files to be used for validation\n"," val_files = source[-round(len(source)*(percentage_validation/100)):]\n"," source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_files)):\n"," writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source_files)):\n"," 
writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"X8YHeSGr76je"},"outputs":[],"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"7Ofm-71T8ABX"},"outputs":[],"source":["#@markdown ##Start training\n","pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","### TRAIN THE MODEL ###\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!/content/pytorch_fnet/scripts/train_model.sh $model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["## **4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on. Make sure your training datasets are in the same location as when you originally trained the model.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"iDIgosht8U7F"},"outputs":[],"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. 
This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we repeat steps already used above in case the notebook needs to be restarted for this cell.\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n","\n","\n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","### Choosing and editing the path names ###\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","\n","full_model_path = Pretrained_model_path+'/'+Pretrained_model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+Pretrained_model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+Pretrained_model_name+'_val\\.csv'\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#We get the example data and the number of images from the csv path file#\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_images = 0\n"," for line in csvreader:\n"," ExampleSource = line[0]\n"," ExampleTarget = line[1]\n"," number_of_images += 1\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'_val.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_val_images = 0\n"," for line in csvreader:\n"," number_of_val_images += 1\n","\n","#Batch Size\n","\n","batch_size = 4 #@param {type:\"number\"}\n","\n","# Editing the train.sh script file #\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" 
/content/pytorch_fnet/scripts/train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" /content/pytorch_fnet/scripts/train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' /content/pytorch_fnet/scripts/train_model.sh\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","# Find the number of steps to add and then add #\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 150#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" /content/pytorch_fnet/scripts/train_model.sh\n","\n","# Display example data #\n","\n","#Load one randomly chosen training source file\n","#random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(ExampleSource)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","y = io.imread(ExampleTarget)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"h1INk9nRE15L"},"outputs":[],"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x /content/pytorch_fnet/scripts/train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" /content/pytorch_fnet/scripts/train_model.sh #change the number of training images"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"5IXdFqhM8gO2"},"outputs":[],"source":["Use_Data_augmentation = False \n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","#Here we retrain the model on the chosen dataset.\n","!chmod u+x /content/pytorch_fnet//scripts/train_model.sh\n","!/content/pytorch_fnet//scripts/train_model.sh $Pretrained_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Here, we redefine the variable names for the pdf export\n","percentage_validation = round((number_of_val_images/(number_of_images+number_of_val_images))*100)\n","steps = new_steps\n","model_name = Pretrained_model_name\n","model_path = Pretrained_model_path\n","Training_source = os.path.dirname(ExampleSource)\n","Training_target = os.path.dirname(ExampleTarget)\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"eAJzMwPA6tlH"},"outputs":[],"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again while the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"outputs":[],"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","\n","# Read the training loss values from the losses.csv file saved during training\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n","    plots = csv.reader(csvfile, delimiter=',')\n","    next(plots)\n","    for row in plots:\n","        iterationNumber_training.append(int(row[0]))\n","        lossDataFromCSV.append(float(row[1]))\n","\n","# Read the validation loss values from losses_val.csv\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n","    plots = csv.reader(csvfile_val, delimiter=',')\n","    next(plots)\n","    for row in plots:\n","        iterationNumber_val.append(int(row[0]))\n","        vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/QualityControl/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps, and calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the mean SSIM value calculated across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM.**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or by the type of GPU allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU, or use a dataset with smaller images.\n","\n","**Note 2:** If you get a 'sequence argument must have length equal to input rank' error, you may need to reshape your images from [z, x, y, c] or [c,z,x,y] to [z,x,y] by squeezing out the channel dimension, e.g. using numpy.squeeze.
"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"w90MdriMxhjD"},"outputs":[],"source":["#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","### Choosing and editing the path names ###\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\")\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","new_full_QC_model_path = convert_to_script_compatible_path(full_QC_model_path)\n","new_full_QC_model_path_dataset = new_full_QC_model_path+'\\${DATASET}'\n","new_full_QC_model_path_csv = new_full_QC_model_path+'\\/QualityControl\\/qc\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","\n","### Editing the predict.sh script file ###\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","!chmod u+x /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","!if ! 
grep class_dataset /content/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset /content/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","\n","### Create a path csv file for prediction (QC)###\n","\n","#Here we create a qctest.csv to locate the files used for QC\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","\n","with open(full_QC_model_path+'/QualityControl/qctest.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Source_QC_folder+'/'+test_signal[i],Target_QC_folder+'/'+test_signal[i]])\n","\n","### RUN THE PREDICTION ###\n","!/content/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","### Save the results ###\n","QC_results_files = os.listdir(full_QC_model_path+'/QualityControl/Predictions')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," if os.path.isdir(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/prediction_'+QC_model_name+'.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction/'+'Predicted_'+test_signal[i])\n"," if os.path.exists(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff'):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n"," else:\n"," shutil.copyfile(Source_QC_folder+'/'+test_signal[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(Target_QC_folder+'/'+test_target[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n","\n","shutil.rmtree(full_QC_model_path+'/QualityControl/Predictions')\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane 
slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/QualityControl/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to 
display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 
'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/QualityControl/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"y2TD5p7MZrEb"},"outputs":[],"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(Results_folder+\"/Predictions\"):\n"," shutil.rmtree(Results_folder+\"/Predictions\")\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","### Choosing and editing the path names ###\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+model_name\n","\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","Prediction_model_name_x = Prediction_model_name+\"}\"\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Convert the path variables into a compatible format with the script files #\n","# Prediction path conversion\n","new_full_Prediction_model_path = convert_to_script_compatible_path(full_Prediction_model_path)\n","new_full_Prediction_model_path_csv = new_full_Prediction_model_path+'\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","# Result path conversion\n","new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","!chmod u+x /content/pytorch_fnet/scripts/predict.sh\n","\n","### Editing the predict.sh script file ###\n","\n","# Make sure the dataset type is set to .tif (debug note: could be changed at install in predict.py file?)\n","!if ! grep class_dataset /content/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset /content/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","#We change the directories in the predict.sh file to our needed paths\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" /content/pytorch_fnet/scripts/predict.sh\n","\n","# Changing the GPU ID seems to help reduce errors\n","replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n","\n","# We get rid of the options of saving signals and targets. Here, we just want predictions.\n","insert_1 = ' --no_signal \\\\\\n'\n","insert_2 = ' --no_target \\\\\\n'\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_1,\"\")\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_2,\"\")\n","\n","### Create the path csv file for prediction ###\n","\n","#Here we create a new test.csv with the paths to the dataset we want to predict on.\n","test_signal = os.listdir(Data_folder)\n","with open(full_Prediction_model_path+'/test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Data_folder+\"/\"+test_signal[i],Data_folder+\"/\"+test_signal[i]])\n","\n","### WE RUN THE PREDICTION ###\n","start = time.time()\n","!/content/pytorch_fnet/scripts/predict.sh $Prediction_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Rename the results appropriately\n","Results = os.listdir(Results_folder+'/Predictions')\n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.copyfile(Results_folder+'/Predictions/'+i+'/'+os.listdir(Results_folder+'/Predictions/'+i)[0],Results_folder+'/Predictions/'+'predicted_'+test_signal[int(i)])\n"," \n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.rmtree(Results_folder+'/Predictions/'+i)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["## **6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. 
Select the slice of the slice you want to visualize."]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"66-af3rO9vM4"},"outputs":[],"source":["#@markdown ###Select the slice would you like to view?\n","slice_number = 15#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," #source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Results_folder,Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n"," #Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"uRcJEjslvTj2"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version has an additional step before re-training in section 4.2. which allows to change the number of images loaded into buffer.\n","* An additional note is given for the QC step, indicating the shape of the image files.\n","* Existing model files are now overwritten in an additional section before the training cell, allowing errors to be corrected before the model folder is overwritten.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["IkSguVy8Xv83","jWAz2i7RdxUV","gKDLkLWUd-YX","UvSlTaH14s3t"],"machine_shape":"hm","name":"fnet_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1SisekHpRSJ0QKHvDePqFe09lkklVytwI","timestamp":1622728423435},{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1620660071757},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}]},"kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"}},"nbformat":4,"nbformat_minor":0} diff --git a/requirements_files/CARE_2D_requirements_simple.txt b/requirements_files/CARE_2D_requirements_simple.txt index 45cd6214..97b96bf6 100644 --- a/requirements_files/CARE_2D_requirements_simple.txt +++ b/requirements_files/CARE_2D_requirements_simple.txt @@ -1,37 +1,21 @@ -astropy==4.2 -cloudpickle==1.3.0 -csbdeep==0.6.1 -dask==2.12.0 -fpdf==1.7.2 -gast==0.3.3 +# Requirements for CARE_2D_ZeroCostDL4Mic.ipynb +Augmentor==0.2.12 +astropy==5.2.2 +csbdeep==0.7.4 +fpdf2==2.7.4 +future==0.18.3 google==2.0.3 -h5py==2.10.0 -httplib2==0.17.4 -image==1.5.33 -imageio==2.4.1 -joblib==1.0.1 -multiprocess==0.70.11.1 -numexpr==2.7.2 -numpy==1.19.5 -oauth2client==4.1.3 -pandas==1.1.5 -portpicker==1.3.1 -py==1.10.0 -pyasn1==0.4.8 -pydot==1.3.0 -rsa==4.7.2 -scipy==1.4.1 -six==1.15.0 -scikit-image==0.16.2 -scikit-learn==0.22.2.post1 -sortedcontainers==2.3.0 -tblib==1.7.0 -tensorboard==1.15.0 -tensorflow==1.15.2 -termcolor==1.1.0 -tifffile==2021.2.26 -toolz==0.11.1 -tqdm==4.41.1 -uritemplate==3.0.1 +matplotlib==3.7.1 +memory-profiler==0.61.0 +numexpr==2.8.4 +numpy==1.22.4 +pandas==1.5.3 +pathlib==1.0.1 +pip==23.1.2 +scikit-image==0.19.3 +scikit-learn==1.2.2 +scipy==1.10.1 +tensorflow==2.12.0 +tifffile==2023.7.18 +tqdm==4.65.0 wget==3.2 -wrapt==1.12.1 diff --git a/requirements_files/CARE_3D_requirements_simple.txt b/requirements_files/CARE_3D_requirements_simple.txt index 45cd6214..62c42e40 100644 --- a/requirements_files/CARE_3D_requirements_simple.txt +++ b/requirements_files/CARE_3D_requirements_simple.txt @@ -1,37 +1,21 @@ -astropy==4.2 -cloudpickle==1.3.0 -csbdeep==0.6.1 -dask==2.12.0 -fpdf==1.7.2 -gast==0.3.3 +# Requirements for CARE_3D_ZeroCostDL4Mic.ipynb +astropy==5.2.2 +csbdeep==0.7.4 +fpdf2==2.7.4 +future==0.18.3 google==2.0.3 -h5py==2.10.0 -httplib2==0.17.4 -image==1.5.33 -imageio==2.4.1 -joblib==1.0.1 -multiprocess==0.70.11.1 -numexpr==2.7.2 -numpy==1.19.5 -oauth2client==4.1.3 -pandas==1.1.5 -portpicker==1.3.1 -py==1.10.0 -pyasn1==0.4.8 -pydot==1.3.0 -rsa==4.7.2 -scipy==1.4.1 -six==1.15.0 -scikit-image==0.16.2 -scikit-learn==0.22.2.post1 -sortedcontainers==2.3.0 -tblib==1.7.0 -tensorboard==1.15.0 -tensorflow==1.15.2 -termcolor==1.1.0 -tifffile==2021.2.26 -toolz==0.11.1 -tqdm==4.41.1 -uritemplate==3.0.1 +ipywidgets==8.0.7 +matplotlib==3.7.1 +memory-profiler==0.61.0 +numexpr==2.8.4 +numpy==1.22.4 +pandas==1.5.3 +pathlib==1.0.1 +pip==23.1.2 +scikit-image==0.19.3 +scikit-learn==1.2.2 +scipy==1.10.1 +tensorflow==2.8.0 +tifffile==2023.7.18 +tqdm==4.65.0 wget==3.2 -wrapt==1.12.1