diff --git a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb index a8836447..d3eb5354 100644 --- a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "IkSguVy8Xv83" @@ -21,7 +20,7 @@ "\n", "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", "\n", - "This notebook is based on the following paper: \n", + "This notebook is based on the following paper:\n", "\n", " **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n", "\n", @@ -31,7 +30,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "W7HfryEazzJE" @@ -121,7 +119,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "jWAz2i7RdxUV" @@ -154,7 +151,7 @@ "\n", "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", + "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here.\n", "\n", "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", "\n", @@ -170,7 +167,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "gKDLkLWUd-YX" @@ -216,7 +212,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "AdN8B91xZO0x" @@ -230,22 +225,22 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "fq21zJVFNASx", "pycharm": { "name": "#%%\n" - } + }, + "cellView": "form" }, "outputs": [], "source": [ - "Notebook_version = '1.15.3'\n", + "Notebook_version = '1.16.1'\n", "Network = 'pix2pix'\n", "\n", "\n", "from builtins import any as b_any\n", "\n", "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", + " # Store requirements file in 'contents' directory\n", " current_dir = os.getcwd()\n", " dir_count = current_dir.count('/') - 1\n", " path = '../' * (dir_count) + 'requirements.txt'\n", @@ -272,7 +267,7 @@ "\n", " # Replace with package name and handle cases where import name is different to module name\n", " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", + " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n", " filtered_list = filter_files(req_list, mod_replace_list)\n", "\n", " file=open(path,'w')\n", @@ -289,11 +284,13 @@ "#Here, we install libraries which are not already included in Colab.\n", "!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n", "import os\n", - "os.chdir('pytorch-CycleGAN-and-pix2pix/')\n", + "pix2pix_code_dir = os.getcwd()\n", + "os.chdir(os.path.join(pix2pix_code_dir, 
\"pytorch-CycleGAN-and-pix2pix\"))\n", "!pip install -r requirements.txt\n", "!pip install fpdf2\n", "!pip install lpips\n", "\n", + "\n", "import lpips\n", "from PIL import Image\n", "import imageio\n", @@ -307,7 +304,7 @@ "from matplotlib import pyplot as plt\n", "import urllib\n", "import os, random\n", - "import shutil \n", + "import shutil\n", "import zipfile\n", "from tifffile import imread, imsave\n", "import time\n", @@ -316,8 +313,7 @@ "import pandas as pd\n", "import csv\n", "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", + "from scipy import signal, ndimage, stats\n", "from skimage import io\n", "from sklearn.linear_model import LinearRegression\n", "from skimage.util import img_as_uint\n", @@ -327,11 +323,13 @@ "from astropy.visualization import simple_norm\n", "from skimage import img_as_float32\n", "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", + "from tqdm import tqdm\n", "from fpdf import FPDF, HTMLMixin\n", "from datetime import datetime\n", "import subprocess\n", "from pip._internal.operations.freeze import freeze\n", + "import glob\n", + "import cv2\n", "\n", "# Colors for the warning messages\n", "class bcolors:\n", @@ -356,8 +354,314 @@ " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", "\n", "# average function\n", - "def Average(lst): \n", - " return sum(lst) / len(lst) \n", + "def Average(lst):\n", + " return sum(lst) / len(lst)\n", + "def ssim(img1, img2):\n", + " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", + "\n", + "\n", + "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", + " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", + " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", + " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", + "\n", + "\n", + "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):\n", + " x = x.astype(dtype,copy=False)\n", + " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", + " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", + " eps = dtype(eps)\n", + "\n", + " try:\n", + " import numexpr\n", + " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", + " except ImportError:\n", + " x = (x - mi) / ( ma - mi + eps )\n", + "\n", + " if clip:\n", + " x = np.clip(x,0,1)\n", + "\n", + " return x\n", + "\n", + "def norm_minmse(gt, x):\n", + " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", + " x = x.astype(np.float32, copy=False) - np.mean(x)\n", + " #x = x - np.mean(x)\n", + " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", + " #gt = gt - np.mean(gt)\n", + " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", + " return gt, scale * x\n", + "\n", + "def prepare_qc_dir(QC_model_path, QC_model_name, Im_path):\n", + " # Create a quality control/Prediction Folder\n", + " QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n", + "\n", + " if os.path.exists(QC_prediction_results):\n", + " print(\"The QC folder has been removed to save the new results\")\n", + " shutil.rmtree(QC_prediction_results)\n", + "\n", + " os.makedirs(QC_prediction_results, exist_ok=True)\n", + "\n", + " # Here we need to move the data to be analysed so that pix2pix 
can find them\n", + "  Saving_path_QC = os.path.join(Im_path, QC_model_name+\"_images\")\n", + "\n", + "  #if os.path.exists(Saving_path_QC):\n", + "  #  shutil.rmtree(Saving_path_QC)\n", + "  os.makedirs(Saving_path_QC, exist_ok=True)\n", + "\n", + "  # Folder to save the images in the correct format\n", + "  Saving_path_QC_folder = os.path.join(Saving_path_QC, \"QC\")\n", + "\n", + "  os.makedirs(Saving_path_QC_folder, exist_ok=True)\n", + "\n", + "  imageA_folder = os.path.join(Saving_path_QC_folder, \"A\", \"test\")\n", + "  os.makedirs(imageA_folder, exist_ok=True)\n", + "\n", + "  imageB_folder = os.path.join(Saving_path_QC_folder, \"B\", \"test\")\n", + "  os.makedirs(imageB_folder, exist_ok=True)\n", + "\n", + "  imageAB_folder = os.path.join(Saving_path_QC_folder, \"AB\", \"test\")\n", + "  os.makedirs(imageAB_folder, exist_ok=True)\n", + "  return QC_prediction_results, Saving_path_QC_folder\n", + "\n", + "#Here we copy and normalise the data\n", + "def normalise_data(Source_QC_folder, Target_QC_folder, Normalisation_QC_source, Normalisation_QC_target, Im_path):\n", + "  if Normalisation_QC_source == \"Contrast stretching\":\n", + "\n", + "    for filename in os.listdir(Source_QC_folder):\n", + "\n", + "      img = io.imread(os.path.join(Source_QC_folder,filename)).astype(np.float32)\n", + "      short_name = os.path.splitext(filename)\n", + "\n", + "      p2, p99 = np.percentile(img, (2, 99.9))\n", + "      img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", + "\n", + "      img = 255 * img # Now scale by 255\n", + "      img = img.astype(np.uint8)\n", + "      cv2.imwrite(os.path.join(Im_path, \"A\", \"test\", f\"{short_name[0]}.png\"), img)\n", + "\n", + "  if Normalisation_QC_target == \"Contrast stretching\":\n", + "    for filename in os.listdir(Target_QC_folder):\n", + "\n", + "      img = io.imread(os.path.join(Target_QC_folder,filename)).astype(np.float32)\n", + "      short_name = os.path.splitext(filename)\n", + "\n", + "      p2, p99 = np.percentile(img, (2, 99.9))\n", + "      img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", + "\n", + "      img = 255 * img # Now scale by 255\n", + "      img = img.astype(np.uint8)\n", + "      cv2.imwrite(os.path.join(Im_path, \"B\", \"test\", f\"{short_name[0]}.png\"), img)\n", + "\n", + "  if Normalisation_QC_source == \"Adaptive Equalization\":\n", + "    for filename in os.listdir(Source_QC_folder):\n", + "\n", + "      img = io.imread(os.path.join(Source_QC_folder,filename))\n", + "      short_name = os.path.splitext(filename)\n", + "\n", + "      img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", + "\n", + "      img = 255 * img # Now scale by 255\n", + "      img = img.astype(np.uint8)\n", + "      cv2.imwrite(os.path.join(Im_path, \"A\", \"test\", f\"{short_name[0]}.png\"), img)\n", + "\n", + "\n", + "  if Normalisation_QC_target == \"Adaptive Equalization\":\n", + "    for filename in os.listdir(Target_QC_folder):\n", + "\n", + "      img = io.imread(os.path.join(Target_QC_folder,filename))\n", + "      short_name = os.path.splitext(filename)\n", + "\n", + "      img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", + "\n", + "      img = 255 * img # Now scale by 255\n", + "      img = img.astype(np.uint8)\n", + "      cv2.imwrite(os.path.join(Im_path, \"B\", \"test\", f\"{short_name[0]}.png\"), img)\n", + "\n", + "  if Normalisation_QC_source == \"None\":\n", + "    for files in os.listdir(Source_QC_folder):\n", + "      shutil.copyfile(os.path.join(Source_QC_folder, files), os.path.join(Im_path, \"A\", \"test\", files))\n", + "\n", + "  if Normalisation_QC_target == \"None\":\n", + "    for files in os.listdir(Target_QC_folder):\n", + "      
shutil.copyfile(os.path.join(Target_QC_folder, files), os.path.join(Im_path, \"B\", \"test\", files))\n", + "\n", + "def QC_RGB(Source_QC_folder, QC_folder):\n", + "\n", + " # List images in Source_QC_folder\n", + " csv_file = os.path.join(QC_folder, f\"QC_metrics_{QC_model_name}{str(checkpoints)}.csv\")\n", + " # Open and create the csv file that will contain all the QC metrics\n", + " with open(csv_file, \"w\", newline='') as file:\n", + " writer = csv.writer(file)\n", + " # Write the header in the csv file\n", + " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\" , \"Prediction v. GT lpips\", \"Input v. GT lpips\"])\n", + " # Initiate list\n", + " ssim_score_list = []\n", + " lpips_score_list = []\n", + "\n", + " # Let's loop through the provided dataset in the QC folders\n", + " for i in os.listdir(Source_QC_folder):\n", + " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", + " print('Running QC on: '+i)\n", + "\n", + " shortname_no_PNG = i[:-4]\n", + "\n", + " # -------------------------------- Target test data (Ground truth) --------------------------------\n", + " test_GT = imageio.imread(os.path.join(QC_folder, f\"{shortname_no_PNG}_real_B.png\"))\n", + " # -------------------------------- Source test data --------------------------------\n", + " test_source = imageio.imread(os.path.join(QC_folder, f\"{shortname_no_PNG}_real_A.png\"))\n", + " # -------------------------------- Prediction --------------------------------\n", + " test_prediction = imageio.imread(os.path.join(QC_folder,f\"{shortname_no_PNG}_fake_B.png\"))\n", + " #--------------------------- Here we normalise using histograms matching--------------------------------\n", + " test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n", + " test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n", + " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", + " # Calculate the SSIM maps\n", + " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n", + " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n", + " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", + "\n", + " #Save ssim_maps\n", + " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"SSIM_GTvsPrediction_{shortname_no_PNG}.tif\"), img_SSIM_GTvsPrediction_8bit)\n", + "\n", + " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"SSIM_GTvsSource_{shortname_no_PNG}.tif\"),img_SSIM_GTvsSource_8bit)\n", + "\n", + " # -------------------------------- Pearson correlation coefficient --------------------------------\n", + "\n", + " # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n", + " if Do_lpips_analysis:\n", + "\n", + " lpips_GTvsPrediction = perceptual_diff(test_GT, test_prediction, 'alex', True)\n", + " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", + " lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", + " lpips_score_list.append(lpips_GTvsPrediction_score)\n", + "\n", + "\n", + " lpips_GTvsSource = perceptual_diff(test_GT, test_source, 'alex', True)\n", + " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", + " lpips_GTvsSource_score= 
lpips_GTvsSource.mean().data.numpy()\n", + "\n", + "\n", + " #lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"lpips_GTvsPrediction_{shortname_no_PNG}.tif\"),lpips_GTvsPrediction_image)\n", + "\n", + " #lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"lpips_GTvsInput_{shortname_no_PNG}.tif\"),lpips_GTvsSource_image)\n", + " else:\n", + " lpips_GTvsPrediction_score = 0\n", + " lpips_score_list.append(lpips_GTvsPrediction_score)\n", + "\n", + " lpips_GTvsSource_score = 0\n", + "\n", + " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource), str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", + "\n", + " #Here we calculate the ssim average for each image in each checkpoints\n", + " Average_SSIM_checkpoint = Average(ssim_score_list)\n", + " Average_lpips_checkpoint = Average(lpips_score_list)\n", + "\n", + " return Average_SSIM_checkpoint, Average_lpips_checkpoint\n", + "\n", + "def QC_singlechannel(Source_QC_folder, QC_folder):\n", + " csv_file = os.path.join(QC_folder, f\"QC_metrics_{QC_model_name}{str(checkpoints)}.csv\")\n", + " # Open and create the csv file that will contain all the QC metrics\n", + " with open(csv_file, \"w\", newline='') as file:\n", + " writer = csv.writer(file)\n", + "\n", + " # Write the header in the csv file\n", + " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\", \"Prediction v. GT lpips\", \"Input v. GT lpips\"])\n", + "\n", + " # Initialize the lists\n", + " ssim_score_list = []\n", + " Pearson_correlation_coefficient_list = []\n", + " lpips_score_list = []\n", + "\n", + " # Let's loop through the provided dataset in the QC folders\n", + " for i in os.listdir(Source_QC_folder):\n", + "\n", + " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", + " print('Running QC on: '+i)\n", + "\n", + " shortname_no_PNG = i[:-4]\n", + "\n", + " # -------------------------------- Target test data (Ground truth) --------------------------------\n", + " test_GT_raw = imageio.imread(os.path.join(QC_folder, f\"{shortname_no_PNG}_real_B.png\"))\n", + " test_GT = test_GT_raw[:,:,2]\n", + " # -------------------------------- Source test data --------------------------------\n", + " test_source_raw = imageio.imread(os.path.join(QC_folder, f\"{shortname_no_PNG}_real_A.png\"))\n", + " test_source = test_source_raw[:,:,2]\n", + " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", + " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source)\n", + " # -------------------------------- Prediction --------------------------------\n", + " test_prediction_raw = imageio.imread(os.path.join(QC_folder,f\"{shortname_no_PNG}_fake_B.png\"))\n", + " test_prediction = test_prediction_raw[:,:,2]\n", + " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", + " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction)\n", + "\n", + "\n", + " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", + " # Calculate the SSIM maps\n", + " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", + " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", + " 
ssim_score_list.append(index_SSIM_GTvsPrediction)\n", + "\n", + " #Save ssim_maps\n", + "\n", + " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"SSIM_GTvsPrediction_{shortname_no_PNG}.tif\"),img_SSIM_GTvsPrediction_8bit)\n", + " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"SSIM_GTvsSource_{shortname_no_PNG}.tif\"),img_SSIM_GTvsSource_8bit)\n", + "\n", + " # Calculate the Root Squared Error (RSE) maps\n", + " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", + " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", + "\n", + " # Save SE maps\n", + " img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"RSE_GTvsPrediction_{shortname_no_PNG}.tif\"),img_RSE_GTvsPrediction_8bit)\n", + " img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"RSE_GTvsSource_{shortname_no_PNG}.tif\"),img_RSE_GTvsSource_8bit)\n", + "\n", + " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", + " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", + " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", + " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", + "\n", + " # We can also measure the peak signal to noise ratio between the images\n", + " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", + " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", + "\n", + " # ---------------- Calculate the perceptual difference metrics map and save them -------------\n", + " if Do_lpips_analysis:\n", + " lpips_GTvsPrediction = perceptual_diff(test_GT_raw, test_prediction_raw, 'alex', True)\n", + " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", + " lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", + " lpips_score_list.append(lpips_GTvsPrediction_score)\n", + "\n", + " lpips_GTvsSource = perceptual_diff(test_GT_raw, test_source_raw, 'alex', True)\n", + " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", + " lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n", + "\n", + "\n", + " lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"lpips_GTvsPrediction_{shortname_no_PNG}.tif\"),lpips_GTvsPrediction_image_8bit)\n", + "\n", + " lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", + " io.imsave(os.path.join(QC_folder, f\"lpips_GTvsInput_{shortname_no_PNG}.tif\"),lpips_GTvsSource_image_8bit)\n", + " else:\n", + " lpips_GTvsPrediction_score = 0\n", + " lpips_score_list.append(lpips_GTvsPrediction_score)\n", + "\n", + " lpips_GTvsSource_score = 0\n", + "\n", + " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource),str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", + "\n", + " #Here we calculate the ssim average for each image in each checkpoints\n", + " Average_SSIM_checkpoint = Average(ssim_score_list)\n", + " Average_lpips_checkpoint = Average(lpips_score_list)\n", + "\n", + " return 
Average_SSIM_checkpoint, Average_lpips_checkpoint\n", "\n", "def perceptual_diff(im0, im1, network, spatial):\n", "\n", @@ -370,24 +674,24 @@ "\n", " return diff\n", "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", + "def pdf_export(trained = False, augmentation = False, pretrained_model = False, Saving_path=\"/content/\"):\n", " class MyFPDF(FPDF, HTMLMixin):\n", " pass\n", "\n", " pdf = MyFPDF()\n", " pdf.add_page()\n", " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", + " pdf.set_font(\"Arial\", size = 11, style='B')\n", "\n", " Network = 'pix2pix'\n", " day = datetime.now()\n", " datetime_str = str(day)[0:10]\n", "\n", " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", + " pdf.multi_cell(180, 5, txt = Header, align = 'L')\n", " pdf.ln(1)\n", - " \n", - " # add another cell \n", + "\n", + " # add another cell\n", " if trained:\n", " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", @@ -463,7 +767,7 @@ " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", " pdf.ln(1)\n", - " html = \"\"\" \n", + " html = \"\"\"\n", " \n", " \n", " \n", @@ -516,8 +820,8 @@ " pdf.ln(1)\n", " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_pix2pix.png').shape\n", - " pdf.image('/content/TrainingDataExample_pix2pix.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", + " exp_size = io.imread(os.path.join(Saving_path, 'TrainingDataExample_pix2pix.png')).shape\n", + " pdf.image(os.path.join(Saving_path, 'TrainingDataExample_pix2pix.png'), x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", " pdf.ln(1)\n", " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. 
\"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", @@ -543,7 +847,7 @@ " pdf = MyFPDF()\n", " pdf.add_page()\n", " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", + " pdf.set_font(\"Arial\", size = 11, style='B')\n", "\n", " Network = 'pix2pix'\n", "\n", @@ -552,7 +856,7 @@ " datetime_str = str(day)[0:10]\n", "\n", " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", + " pdf.multi_cell(180, 5, txt = Header, align = 'L')\n", " pdf.ln(1)\n", "\n", " all_packages = ''\n", @@ -650,7 +954,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "n4yWFoJNnoin" @@ -669,7 +972,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "DMNHVZfHmbKb" @@ -692,8 +994,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", - "id": "zCvebubeSaGY" + "id": "zCvebubeSaGY", + "cellView": "form" }, "outputs": [], "source": [ @@ -702,8 +1004,8 @@ "\n", "import tensorflow as tf\n", "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", + " print('You do not have GPU access.')\n", + " print('Did you change your runtime ?')\n", " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", " print('Expect slow performance. To access GPU try reconnecting later')\n", "\n", @@ -713,7 +1015,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "sNIVx8_CLolt" @@ -723,19 +1024,13 @@ "---\n", " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", "\n", - " Play the cell below to mount your Google Drive and follow the instructions. \n", + " Play the cell below to mount your Google Drive and follow the instructions.\n", "\n", " Once this is done, your data are available in the **Files** tab on the top left of notebook." ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "outputs": [], "source": [ "#@markdown ##Play the cell to connect your Google Drive to Colab\n", "\n", @@ -743,10 +1038,15 @@ "# mount user's Google Drive to Google Colab.\n", "from google.colab import drive\n", "drive.mount('/content/gdrive')" - ] + ], + "metadata": { + "cellView": "form", + "id": "R3JNPbRA6GAS" + }, + "execution_count": null, + "outputs": [] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "HLYcZR9gMv42" @@ -757,7 +1057,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "FQ_QxtSWQ7CL" @@ -765,11 +1064,10 @@ "source": [ "## **3.1. Setting main training parameters**\n", "---\n", - " " + "" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "id": "AuESFimvMv43", @@ -804,7 +1102,7 @@ "\n", "**`patch_size`:** pix2pix divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. 
**Default value: 512**\n", "\n", - "**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n", + "**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.**\n", "\n", "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n", "\n", @@ -839,22 +1137,18 @@ "\n", "#@markdown ###Image normalisation:\n", "\n", - "Normalisation_training_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "Normalisation_training_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "\n", - "\n", - "#Define where the patch file will be saved\n", - "base = \"/content\"\n", + "Normalisation_training_source = \"Contrast stretching\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", + "Normalisation_training_target = \"Contrast stretching\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", "\n", "# model name and path\n", "#@markdown ###Name of the model and path to model folder:\n", "model_name = \"\" #@param {type:\"string\"}\n", "model_path = \"\" #@param {type:\"string\"}\n", - "\n", + "pix2pix_working_directory = os.getcwd()\n", "# other parameters for training.\n", "#@markdown ###Training Parameters\n", "#@markdown Number of epochs:\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", + "number_of_epochs = 10#@param {type:\"number\"}\n", "\n", "#@markdown ###Advanced Parameters\n", "\n", @@ -865,23 +1159,18 @@ "initial_learning_rate = 0.0002 #@param {type:\"number\"}\n", "\n", "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " patch_size = 512\n", - " initial_learning_rate = 0.0002\n", + "if (Use_Default_Advanced_Parameters):\n", + " print(\"Default advanced parameters enabled\")\n", + " batch_size = 1\n", + " patch_size = 512\n", + " initial_learning_rate = 0.0002\n", "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n", - " \n", "#To use pix2pix we need to organise the data in a way the network can understand\n", "\n", - "Saving_path= \"/content/\"+model_name\n", + "Saving_path= os.path.join(pix2pix_working_directory, model_name)\n", "\n", "if os.path.exists(Saving_path):\n", - " shutil.rmtree(Saving_path)\n", + " shutil.rmtree(Saving_path)\n", "os.makedirs(Saving_path)\n", "\n", "imageA_folder = Saving_path+\"/A\"\n", @@ -895,7 +1184,7 @@ "\n", "TrainA_Folder = Saving_path+\"/A/train\"\n", "os.makedirs(TrainA_Folder)\n", - " \n", + "\n", "TrainB_Folder = Saving_path+\"/B/train\"\n", "os.makedirs(TrainB_Folder)\n", "\n", @@ -909,82 +1198,75 @@ "# Here we normalise the image is enabled\n", "\n", "if Normalisation_training_source == \"Contrast stretching\":\n", + " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", + " os.makedirs(Training_source_norm)\n", + " for filename in os.listdir(Training_source):\n", + " img = io.imread(os.path.join(Training_source,filename)).astype(np.float32)\n", + " short_name = os.path.splitext(filename)\n", + " p2, p99 = np.percentile(img, (2, 99.9))\n", + " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", + " img = 255 * img # Now scale by 255\n", + " img = img.astype(np.uint8)\n", + " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", "\n", - " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", - " os.makedirs(Training_source_norm)\n", - " \n", - " for filename in os.listdir(Training_source):\n", - "\n", - " img = io.imread(os.path.join(Training_source,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", - " \n", - " Training_source = Training_source_norm\n", + " Training_source = Training_source_norm\n", "\n", "\n", "if Normalisation_training_target == \"Contrast stretching\":\n", + " Training_target_norm = Saving_path+\"/Training_target_norm\"\n", + " os.makedirs(Training_target_norm)\n", + " for filename in os.listdir(Training_target):\n", "\n", - " Training_target_norm = Saving_path+\"/Training_target_norm\"\n", - " os.makedirs(Training_target_norm)\n", - "\n", - " for filename in os.listdir(Training_target):\n", - "\n", - " img = io.imread(os.path.join(Training_target,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", + " img = io.imread(os.path.join(Training_target,filename)).astype(np.float32)\n", + " short_name = os.path.splitext(filename)\n", "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", + " p2, p99 = np.percentile(img, (2, 99.9))\n", + " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", + " img = 255 * img # Now scale by 255\n", + " img = img.astype(np.uint8)\n", + " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", "\n", - " Training_target = Training_target_norm\n", + 
" Training_target = Training_target_norm\n", "\n", "\n", "if Normalisation_training_source == \"Adaptive Equalization\":\n", - " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", - " os.makedirs(Training_source_norm)\n", + " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", + " os.makedirs(Training_source_norm)\n", "\n", - " for filename in os.listdir(Training_source):\n", + " for filename in os.listdir(Training_source):\n", "\n", - " img = io.imread(os.path.join(Training_source,filename))\n", - " short_name = os.path.splitext(filename)\n", + " img = io.imread(os.path.join(Training_source,filename))\n", + " short_name = os.path.splitext(filename)\n", "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", + " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", + " img = 255 * img # Now scale by 255\n", + " img = img.astype(np.uint8)\n", + " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", "\n", "\n", - " Training_source = Training_source_norm\n", + " Training_source = Training_source_norm\n", "\n", "\n", "if Normalisation_training_target == \"Adaptive Equalization\":\n", "\n", - " Training_target_norm = Saving_path+\"/Training_target_norm\"\n", - " os.makedirs(Training_target_norm)\n", + " Training_target_norm = Saving_path+\"/Training_target_norm\"\n", + " os.makedirs(Training_target_norm)\n", "\n", - " for filename in os.listdir(Training_target):\n", + " for filename in os.listdir(Training_target):\n", "\n", - " img = io.imread(os.path.join(Training_target,filename))\n", - " short_name = os.path.splitext(filename)\n", + " img = io.imread(os.path.join(Training_target,filename))\n", + " short_name = os.path.splitext(filename)\n", "\n", - " p2, p99 = np.percentile(img, (2, 99.8))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", + " p2, p99 = np.percentile(img, (2, 99.8))\n", + " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", + " img = 255 * img # Now scale by 255\n", + " img = img.astype(np.uint8)\n", + " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", "\n", - " Training_target = Training_target_norm\n", + " Training_target = Training_target_norm\n", "\n", "# This will display a randomly chosen dataset input and output\n", "random_choice = random.choice(os.listdir(Training_source))\n", @@ -998,8 +1280,8 @@ "\n", "#Hyperparameters failsafes\n", "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", + " patch_size = min(Image_Y, Image_X)\n", + " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", "\n", "# Here we check that patch_size is divisible by 4\n", "if not patch_size % 4 == 0:\n", @@ -1008,8 +1290,8 @@ "\n", "# Here we check that patch_size is at least bigger than 256\n", "if patch_size < 256:\n", - " patch_size = 256\n", - " print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n", + 
" patch_size = 256\n", + " print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n", "\n", "y = io.imread(Training_target+\"/\"+random_choice)\n", "\n", @@ -1017,14 +1299,14 @@ "n_channel_y = 1 if y.ndim == 2 else y.shape[-1]\n", "\n", "if n_channel_x == 1:\n", - " cmap_x = 'gray'\n", + " cmap_x = 'gray'\n", "else:\n", - " cmap_x = None\n", + " cmap_x = None\n", "\n", "if n_channel_y == 1:\n", - " cmap_y = 'gray'\n", + " cmap_y = 'gray'\n", "else:\n", - " cmap_y = None\n", + " cmap_y = None\n", "\n", "f=plt.figure(figsize=(16,8))\n", "plt.subplot(1,2,1)\n", @@ -1037,14 +1319,12 @@ "plt.title('Training target')\n", "plt.axis('off');\n", "\n", - "plt.savefig('/content/TrainingDataExample_pix2pix.png',bbox_inches='tight',pad_inches=0)" + "plt.savefig(os.path.join(Saving_path, 'TrainingDataExample_pix2pix.png'),bbox_inches='tight',pad_inches=0)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "hqP3HLsM1w3G", "pycharm": { "name": "#%% md\n" @@ -1057,10 +1337,8 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "eqR8q_AK1w3G", "pycharm": { "name": "#%% md\n" @@ -1075,7 +1353,7 @@ "\n", "Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n", "\n", - "**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** " + "**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.**" ] }, { @@ -1092,16 +1370,16 @@ "source": [ "#Data augmentation\n", "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", + "Use_Data_augmentation = True #@param {type:\"boolean\"}\n", "\n", "if Use_Data_augmentation:\n", - " !pip install Augmentor\n", - " import Augmentor\n", + " !pip install Augmentor\n", + " import Augmentor\n", "\n", "\n", "#@markdown ####Choose a factor by which you want to multiply your original dataset\n", "\n", - "Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n", + "Multiply_dataset_by = 3 #@param {type:\"slider\", min:1, max:30, step:1}\n", "\n", "Save_augmented_images = False #@param {type:\"boolean\"}\n", "\n", @@ -1131,7 +1409,7 @@ "random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", "\n", "\n", - "#@markdown ####Image shearing and skewing \n", + "#@markdown ####Image shearing and skewing\n", "\n", "image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", "max_image_shear = 10 #@param {type:\"slider\", min:1, max:25, step:1}\n", @@ -1142,37 +1420,37 @@ "\n", "\n", "if Use_Default_Augmentation_Parameters:\n", - " rotate_90_degrees = 0.5\n", - " rotate_270_degrees = 0.5\n", - " flip_left_right = 0.5\n", - " flip_top_bottom = 0.5\n", - "\n", - " if not Multiply_dataset_by >5:\n", - " random_zoom = 0\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0\n", - " image_shear = 0\n", - " max_image_shear = 10\n", - " skew_image = 0\n", - " skew_image_magnitude = 0\n", - "\n", - " if Multiply_dataset_by >5:\n", - " random_zoom = 0.1\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0.5\n", - " image_shear = 0.2\n", - " max_image_shear = 5\n", - " skew_image = 0.2\n", - " skew_image_magnitude = 0.4\n", - "\n", - " if Multiply_dataset_by >25:\n", - " random_zoom = 0.5\n", - " random_zoom_magnification = 0.8\n", - " 
random_distortion = 0.5\n", - " image_shear = 0.5\n", - " max_image_shear = 20\n", - " skew_image = 0.5\n", - " skew_image_magnitude = 0.6\n", + " rotate_90_degrees = 0.5\n", + " rotate_270_degrees = 0.5\n", + " flip_left_right = 0.5\n", + " flip_top_bottom = 0.5\n", + "\n", + " if not Multiply_dataset_by >5:\n", + " random_zoom = 0\n", + " random_zoom_magnification = 0.9\n", + " random_distortion = 0\n", + " image_shear = 0\n", + " max_image_shear = 10\n", + " skew_image = 0\n", + " skew_image_magnitude = 0\n", + "\n", + " if Multiply_dataset_by >5:\n", + " random_zoom = 0.1\n", + " random_zoom_magnification = 0.9\n", + " random_distortion = 0.5\n", + " image_shear = 0.2\n", + " max_image_shear = 5\n", + " skew_image = 0.2\n", + " skew_image_magnitude = 0.4\n", + "\n", + " if Multiply_dataset_by >25:\n", + " random_zoom = 0.5\n", + " random_zoom_magnification = 0.8\n", + " random_distortion = 0.5\n", + " image_shear = 0.5\n", + " max_image_shear = 20\n", + " skew_image = 0.5\n", + " skew_image_magnitude = 0.6\n", "\n", "\n", "list_files = os.listdir(Training_source)\n", @@ -1182,101 +1460,100 @@ "\n", "\n", "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", + " print(\"Data augmentation enabled\")\n", "# Here we set the path for the various folder were the augmented images will be loaded\n", "\n", "# All images are first saved into the augmented folder\n", " #Augmented_folder = \"/content/Augmented_Folder\"\n", - " \n", - " if not Save_augmented_images:\n", - " Saving_path= \"/content\"\n", "\n", - " Augmented_folder = Saving_path+\"/Augmented_Folder\"\n", - " if os.path.exists(Augmented_folder):\n", - " shutil.rmtree(Augmented_folder)\n", - " os.makedirs(Augmented_folder)\n", + " if not Save_augmented_images:\n", + " Saving_path= \"/content\"\n", + "\n", + " Augmented_folder = Saving_path+\"/Augmented_Folder\"\n", + " if os.path.exists(Augmented_folder):\n", + " shutil.rmtree(Augmented_folder)\n", + " os.makedirs(Augmented_folder)\n", "\n", - " #Training_source_augmented = \"/content/Training_source_augmented\"\n", - " Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n", + " #Training_source_augmented = \"/content/Training_source_augmented\"\n", + " Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n", "\n", - " if os.path.exists(Training_source_augmented):\n", - " shutil.rmtree(Training_source_augmented)\n", - " os.makedirs(Training_source_augmented)\n", + " if os.path.exists(Training_source_augmented):\n", + " shutil.rmtree(Training_source_augmented)\n", + " os.makedirs(Training_source_augmented)\n", "\n", - " #Training_target_augmented = \"/content/Training_target_augmented\"\n", - " Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n", + " #Training_target_augmented = \"/content/Training_target_augmented\"\n", + " Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n", "\n", - " if os.path.exists(Training_target_augmented):\n", - " shutil.rmtree(Training_target_augmented)\n", - " os.makedirs(Training_target_augmented)\n", + " if os.path.exists(Training_target_augmented):\n", + " shutil.rmtree(Training_target_augmented)\n", + " os.makedirs(Training_target_augmented)\n", "\n", "\n", "# Here we generate the augmented images\n", "#Load the images\n", - " p = Augmentor.Pipeline(Training_source, Augmented_folder)\n", + " p = Augmentor.Pipeline(Training_source, Augmented_folder)\n", "\n", "#Define the matching images\n", - " p.ground_truth(Training_target)\n", + " 
p.ground_truth(Training_target)\n", "#Define the augmentation possibilities\n", - " if not rotate_90_degrees == 0:\n", - "   p.rotate90(probability=rotate_90_degrees)\n", - "   \n", - " if not rotate_270_degrees == 0:\n", - "   p.rotate270(probability=rotate_270_degrees)\n", + "  if not rotate_90_degrees == 0:\n", + "    p.rotate90(probability=rotate_90_degrees)\n", + "\n", + "  if not rotate_270_degrees == 0:\n", + "    p.rotate270(probability=rotate_270_degrees)\n", + "\n", + "  if not flip_left_right == 0:\n", + "    p.flip_left_right(probability=flip_left_right)\n", "\n", - " if not flip_left_right == 0:\n", - "   p.flip_left_right(probability=flip_left_right)\n", + "  if not flip_top_bottom == 0:\n", + "    p.flip_top_bottom(probability=flip_top_bottom)\n", "\n", - " if not flip_top_bottom == 0:\n", - "   p.flip_top_bottom(probability=flip_top_bottom)\n", + "  if not random_zoom == 0:\n", + "    p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n", "\n", - " if not random_zoom == 0:\n", - "   p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n", - "   \n", - " if not random_distortion == 0:\n", - "   p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n", + "  if not random_distortion == 0:\n", + "    p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n", "\n", - " if not image_shear == 0:\n", - "   p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n", - "   \n", - " if not skew_image == 0:\n", - "   p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n", + "  if not image_shear == 0:\n", + "    p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n", "\n", - " p.sample(int(Nb_augmented_files))\n", + "  if not skew_image == 0:\n", + "    p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n", "\n", - " print(int(Nb_augmented_files),\"matching images generated\")\n", + "  p.sample(int(Nb_augmented_files))\n", "\n", + "  print(int(Nb_augmented_files),\"matching images generated\")\n", "\n", "# Here we sort through the images and move them back to the augmented training source and target folders\n", "\n", - " augmented_files = os.listdir(Augmented_folder)\n", + "  augmented_files = os.listdir(Augmented_folder)\n", + "\n", + "  for f in augmented_files:\n", "\n", - " for f in augmented_files:\n", + "    if (f.startswith(\"_groundtruth_(1)_\")):\n", + "      shortname_noprefix = f[17:]\n", + "      shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix)\n", + "    if not (f.startswith(\"_groundtruth_(1)_\")):\n", + "      shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n", "\n", - "   if (f.startswith(\"_groundtruth_(1)_\")):\n", - "     shortname_noprefix = f[17:]\n", - "     shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n", - "   if not (f.startswith(\"_groundtruth_(1)_\")):\n", - "     shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n", - "   \n", "\n", - " for filename in os.listdir(Training_source_augmented):\n", - "   os.chdir(Training_source_augmented)\n", - "   os.rename(filename, filename.replace('_original', ''))\n", - "   \n", - " #Here we clean up the extra files\n", - " shutil.rmtree(Augmented_folder)\n", + "  for filename in os.listdir(Training_source_augmented):\n", + "    # TODO: remove this chdir\n", + "    os.chdir(Training_source_augmented)\n", + "    os.rename(filename, filename.replace('_original', ''))\n", + "\n", + "  #Here we clean up 
the extra files\n", + "  shutil.rmtree(Augmented_folder)\n", "\n", "if not Use_Data_augmentation:\n", - "  print(bcolors.WARNING+\"Data augmentation disabled\") \n", + "  print(bcolors.WARNING+\"Data augmentation disabled\")\n", "\n", "\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "hbrkdhUO1w3H", "pycharm": { "name": "#%% md\n" } }, @@ -1286,7 +1563,7 @@ "\n", "## **3.3. Using weights from a pre-trained model as initial weights**\n", "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n", + " Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**.\n", "\n", " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n" ] }, @@ -1314,30 +1591,24 @@ "\n", "# --------------------- Check if we load a previously trained model ------------------------\n", "if Use_pretrained_model:\n", + "  h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n", + "  # --------------------- Check the model exist ------------------------\n", "\n", - "  h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n", - "  \n", "\n", - "# --------------------- Check the model exist ------------------------\n", "\n", - "  if not os.path.exists(h5_file_path):\n", - "    print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n", - "    Use_pretrained_model = False\n", - "    print(bcolors.WARNING+'No pretrained network will be used.')\n", + "  if not os.path.exists(h5_file_path):\n", + "    print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n", + "    Use_pretrained_model = False\n", + "    print(bcolors.WARNING+'No pretrained network will be used.')\n", "\n", - "  if os.path.exists(h5_file_path):\n", - "    print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n", - "  \n", + "  if os.path.exists(h5_file_path):\n", + "    print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n", "else:\n", - "  print(bcolors.WARNING+'No pretrained network will be used.')\n", + "  print(bcolors.WARNING+'No pretrained network will be used.')\n", "\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "wTFo3ECz1w3I", "pycharm": { "name": "#%% md\n" } }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "G3DrbrSK1w3I", "pycharm": { "name": "#%% md\n" } }, @@ -1371,11 +1640,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "dj-IzrWz1w3I", "pycharm": { "name": "#%%\n" - } + }, + "cellView": "form" }, "outputs": [], "source": [ "#@markdown ##Play this cell to prepare the data for training\n", "\n", "# --------------------- Here we load the augmented data or the raw data ------------------------\n", "\n", "if Use_Data_augmentation:\n", - "  Training_source_dir = Training_source_augmented\n", - "  Training_target_dir = Training_target_augmented\n", + "  Training_source_dir = Training_source_augmented\n", + "  Training_target_dir = Training_target_augmented\n", "\n", 
"if not Use_Data_augmentation:\n", - " Training_source_dir = Training_source\n", - " Training_target_dir = Training_target\n", + " Training_source_dir = Training_source\n", + " Training_target_dir = Training_target\n", "# --------------------- ------------------------------------------------\n", "\n", "print(\"Data preparation in progress\")\n", "\n", "if os.path.exists(model_path+'/'+model_name):\n", - " shutil.rmtree(model_path+'/'+model_name)\n", + " shutil.rmtree(model_path+'/'+model_name)\n", "os.makedirs(model_path+'/'+model_name)\n", "\n", "#--------------- Here we move the files to train A and train B ---------\n", @@ -1413,8 +1682,9 @@ "#---------------------------------------------------------------------\n", "\n", "#--------------- Here we combined A and B images---------\n", - "os.chdir(\"/content\")\n", - "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", + "# TODO: check if we can remove this chdir\n", + "os.chdir(pix2pix_code_dir)\n", + "!python3 pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", "\n", "\n", "\n", @@ -1424,22 +1694,21 @@ "number_of_epochs_lr_decay = int(number_of_epochs/2)\n", "\n", "if Use_pretrained_model :\n", - " for f in os.listdir(pretrained_model_path):\n", - " if (f.startswith(\"latest_net_\")): \n", - " shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n", + " for f in os.listdir(pretrained_model_path):\n", + " if (f.startswith(\"latest_net_\")):\n", + " shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n", "\n", "#Export of pdf summary of training parameters\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", + "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model,\n", + " Saving_path=os.path.join(pix2pix_working_directory, model_name))\n", "\n", "print('------------------------')\n", "print(\"Data ready for training\")\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "3pEg3UzF1w3J", "pycharm": { "name": "#%% md\n" @@ -1459,11 +1728,12 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "92BQcn-w1w3J", "pycharm": { "name": "#%%\n" - } + }, + "scrolled": true, + "cellView": "form" }, "outputs": [], "source": [ @@ -1471,14 +1741,12 @@ "\n", "# get number of channels\n", "if number_channels == \"1\":\n", - " nc = 1\n", + " nc = 1\n", "elif number_channels == \"3\":\n", - " nc = 3\n", + " nc = 3\n", "\n", "start = time.time()\n", "\n", - "os.chdir(\"/content\")\n", - "\n", "#--------------------------------- Command line inputs to change pix2pix paramaters------------\n", "\n", " # basic parameters\n", @@ -1486,7 +1754,7 @@ " #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n", " #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n", " #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n", - " \n", + "\n", " # model parameters\n", " #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n", " #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n", @@ -1500,7 +1768,7 @@ " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n", " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n", " #('--no_dropout', action='store_true', help='no dropout for the generator')\n", - " \n", + "\n", " # dataset parameters\n", " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n", " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", @@ -1513,13 +1781,13 @@ " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", + "\n", " # additional parameters\n", " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", " #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", + "\n", " # visdom and HTML visualization parameters\n", " #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n", " #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n", @@ -1530,7 +1798,7 @@ " #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n", " #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n", " #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n", - " \n", + "\n", " # network saving and loading parameters\n", " #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n", " #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n", @@ -1538,7 +1806,7 @@ " #('--continue_train', action='store_true', help='continue training: load the latest model')\n", " #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n", " #('--phase', type=str, default='train', help='train, val, test, etc')\n", - " \n", + "\n", " # training parameters\n", " #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n", " #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n", @@ -1553,10 +1821,10 @@ "\n", "#----- Start the training ------------------------------------\n", "if not Use_pretrained_model:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix 
--batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", + " !python3 pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", "\n", "if Use_pretrained_model:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", + " !python3 pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", "\n", "\n", "#---------------------------------------------------------\n", @@ -1565,19 +1833,18 @@ "\n", "# Displaying the time elapsed for training\n", "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", + "mins, sec = divmod(dt, 60)\n", + "hour, mins = divmod(mins, 60)\n", "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", "\n", "# Export pdf summary after training to update document\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n" + "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model,\n", + " Saving_path=os.path.join(pix2pix_working_directory, model_name))\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "aMSX_rxu1w3K", "pycharm": { "name": "#%% md\n" @@ -1587,17 +1854,15 @@ "# **5. Evaluate your model**\n", "---\n", "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n", + "This section allows the user to perform important quality checks on the validity and generalisability of the trained model.\n", "\n", "**We highly recommend to perform quality control on all newly trained models.**\n", "\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "9TpiOXpP1w3K", "pycharm": { "name": "#%% md\n" @@ -1621,35 +1886,36 @@ "source": [ "# model name and path\n", "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = False #@param {type:\"boolean\"}\n", + "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", "\n", "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", + "# model name and path\n", + "#@markdown ###Name of the model and path to model folder:\n", "QC_model_folder = \"\" #@param {type:\"string\"}\n", "\n", "#Here we define the loaded model name and path\n", "QC_model_name = os.path.basename(QC_model_folder)\n", "QC_model_path = os.path.dirname(QC_model_folder)\n", "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", + "if (Use_the_current_trained_model):\n", + " QC_model_name = model_name\n", + " QC_model_path = model_path\n", + "else:\n", + " pix2pix_working_directory = os.getcwd()\n", "\n", "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", + " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" + " W = '\\033[0m' # white (normal)\n", + " R = '\\033[31m' # red\n", + " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", + " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "55F-b-gO1w3K", "pycharm": { "name": "#%% md\n" @@ -1660,10 +1926,8 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "6sL8A89z1w3K", "pycharm": { "name": "#%% md\n" @@ -1674,15 +1938,15 @@ "\n", "This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n", "\n", - "**1. The SSIM (structural similarity) map** \n", + "**1. The SSIM (structural similarity) map**\n", "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", + "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. 
 "\n",
 "**mSSIM** is the SSIM value calculated across the entire window of both images.\n",
 "\n",
 "**The output below shows the SSIM maps with the mSSIM**\n",
 "\n",
- "**2. The RSE (Root Squared Error) map** \n",
+ "**2. The RSE (Root Squared Error) map**\n",
 "\n",
 "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n",
 "\n",
@@ -1697,523 +1961,175 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "egpLYnYp1w3L",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
 "source": [
 "#@markdown ##Choose the folders that contain your Quality Control dataset\n",
- "\n",
- "import glob\n",
- "import os.path\n",
- "from scipy import stats\n",
- "\n",
 "#@markdown ###Path to images:\n",
- "\n",
- "Source_QC_folder = \"\" #@param{type:\"string\"} \n",
+ "Source_QC_folder = \"\" #@param{type:\"string\"}\n",
 "Target_QC_folder = \"\" #@param{type:\"string\"}\n",
 "\n",
+ "#@markdown ###Do you need to prepare the test data again?\n",
+ "prepare_testdata = False #@param {type:\"boolean\"}\n",
+ "#@markdown If not, provide the path to the folder that already contains the prepared images, e.g., \"/content/my_model_images/QC\"\n",
+ "path2im = '' #@param{type:\"string\"}\n",
+ "\n",
 "#@markdown ###Number of channels:\n",
 "\n",
 "number_channels = \"1\" #@param [\"1\", \"3\"]\n",
 "\n",
 "# get number of channels\n",
 "if number_channels == \"1\":\n",
- "    nc = 1 \n",
+ "    nc = 1\n",
 "elif number_channels == \"3\":\n",
- "    nc = 3 \n",
+ "    nc = 3\n",
 "\n",
 "#@markdown ###Image normalisation:\n",
+ "Normalisation_QC_source = \"Contrast stretching\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
+ "Normalisation_QC_target = \"Contrast stretching\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
 "\n",
- "Normalisation_QC_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
- "Normalisation_QC_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
- "\n",
+ "#@markdown ###Have you already evaluated all the checkpoints and only want to visualise the results? (This significantly reduces the processing time.)\n",
+ "avoid_evaluating_again = False #@param {type:\"boolean\"}\n",
+ "#@markdown ####Choose the frequency of checkpoints to evaluate. If set to 1, all available model checkpoints will be evaluated.\n",
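For reference, the sketch below mirrors the percentile-based rescaling that the "Contrast stretching" option applies before the 8-bit PNGs are written (a simplified stand-in for the notebook's `normalise_data` step; the function name and paths are illustrative):

```python
import numpy as np
import cv2
from skimage import io, exposure

def contrast_stretch_to_png(in_path, out_path):
    """Rescale intensities between the 2nd and 99.9th percentiles, then save as 8-bit PNG."""
    img = io.imread(in_path).astype(np.float32)
    p2, p99 = np.percentile(img, (2, 99.9))
    # For non-negative input, rescale_intensity maps the image into [0, 1]
    img = exposure.rescale_intensity(img, in_range=(p2, p99))
    cv2.imwrite(out_path, (255 * img).astype(np.uint8))

# Illustrative usage (placeholder paths):
# contrast_stretch_to_png("/content/source/im_0001.tif", "/content/QC/A/test/im_0001.png")
```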
+ "QC_evaluation_checkpoint_freq = 1 #@param {type:\"number\"}\n",
+ "QC_freq = int(QC_evaluation_checkpoint_freq)\n",
 "\n",
 "#@markdown ###Advanced Parameters\n",
- "\n",
 "patch_size_QC = 1024 #@param {type:\"number\"} # in pixels\n",
 "Do_lpips_analysis = True #@param {type:\"boolean\"}\n",
 "\n",
 "\n",
 "\n",
- "# Create a quality control folder\n",
- "\n",
- "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n",
- "    shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n",
- "\n",
- "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n",
- "\n",
+ "if not avoid_evaluating_again:\n",
+ "    # Create a quality control folder\n",
+ "    if os.path.exists(os.path.join(QC_model_path, QC_model_name, \"Quality Control\")):\n",
+ "        shutil.rmtree(os.path.join(QC_model_path, QC_model_name, \"Quality Control\"))\n",
+ "    os.makedirs(os.path.join(QC_model_path, QC_model_name, \"Quality Control\"))\n",
 "\n",
 "# Create a quality control/Prediction Folder\n",
+ "if prepare_testdata or not avoid_evaluating_again:\n",
+ "    QC_prediction_results, path2im = prepare_qc_dir(QC_model_path, QC_model_name, pix2pix_working_directory)\n",
+ "else:\n",
+ "    QC_prediction_results = os.path.join(QC_model_path, QC_model_name, \"Quality Control\", \"Prediction\")\n",
 "\n",
- "QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n",
- "\n",
- "if os.path.exists(QC_prediction_results):\n",
- "    shutil.rmtree(QC_prediction_results)\n",
- "\n",
- "os.makedirs(QC_prediction_results)\n",
- "\n",
- "# Here we count how many images are in our folder to be predicted and we had a few\n",
- "Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n",
- "\n",
- "# List images in Source_QC_folder\n",
- "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n",
- "random_choice = random.choice(os.listdir(Source_QC_folder))\n",
- "x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n",
- "\n",
- "#Find image XY dimension\n",
- "Image_Y = x.shape[0]\n",
- "Image_X = x.shape[1]\n",
- "\n",
- "Image_min_dim = min(Image_Y, Image_X)\n",
- "\n",
- "# Here we need to move the data to be analysed so that pix2pix can find them\n",
- "\n",
- "Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n",
- "\n",
- "if os.path.exists(Saving_path_QC):\n",
- "    shutil.rmtree(Saving_path_QC)\n",
- "os.makedirs(Saving_path_QC)\n",
- "\n",
- "Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n",
- "\n",
- "if os.path.exists(Saving_path_QC_folder):\n",
- "    shutil.rmtree(Saving_path_QC_folder)\n",
- "os.makedirs(Saving_path_QC_folder)\n",
- "\n",
- "imageA_folder = Saving_path_QC_folder+\"/A\"\n",
- "os.makedirs(imageA_folder)\n",
- "\n",
- "imageB_folder = Saving_path_QC_folder+\"/B\"\n",
- "os.makedirs(imageB_folder)\n",
- "\n",
- "imageAB_folder = Saving_path_QC_folder+\"/AB\"\n",
- "os.makedirs(imageAB_folder)\n",
- "\n",
- "testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n",
- "os.makedirs(testAB_folder)\n",
- "\n",
- "testA_Folder = Saving_path_QC_folder+\"/A/test\"\n",
- "os.makedirs(testA_Folder)\n",
- " \n",
- "testB_Folder = Saving_path_QC_folder+\"/B/test\"\n",
- "os.makedirs(testB_Folder)\n",
- "\n",
- "QC_checkpoint_folders = \"/content/\"+QC_model_name\n",
- "\n",
- "if os.path.exists(QC_checkpoint_folders):\n",
- "    shutil.rmtree(QC_checkpoint_folders)\n",
- "os.makedirs(QC_checkpoint_folders)\n",
- "\n",
- "#Here we copy and normalise the data\n",
- 
"\n", - "if Normalisation_QC_source == \"Contrast stretching\":\n", - " \n", - " for filename in os.listdir(Source_QC_folder):\n", - "\n", - " img = io.imread(os.path.join(Source_QC_folder,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - " \n", - "if Normalisation_QC_target == \"Contrast stretching\":\n", - "\n", - " for filename in os.listdir(Target_QC_folder):\n", - "\n", - " img = io.imread(os.path.join(Target_QC_folder,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "if Normalisation_QC_source == \"Adaptive Equalization\":\n", - " for filename in os.listdir(Source_QC_folder):\n", - "\n", - " img = io.imread(os.path.join(Source_QC_folder,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "\n", - "if Normalisation_QC_target == \"Adaptive Equalization\":\n", - "\n", - " for filename in os.listdir(Target_QC_folder):\n", - "\n", - " img = io.imread(os.path.join(Target_QC_folder,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "if Normalisation_QC_source == \"None\":\n", - " for files in os.listdir(Source_QC_folder):\n", - " shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n", - "\n", - "if Normalisation_QC_target == \"None\":\n", - " for files in os.listdir(Target_QC_folder):\n", - " shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n", - "\n", - "\n", - "#Here we create a merged folder containing only imageA\n", - "os.chdir(\"/content\")\n", + "if not avoid_evaluating_again:\n", "\n", - "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", + " print(\"-------------------------------------------------------------\")\n", + " print(\"Path where the predictions are stored\")\n", + " print(QC_prediction_results)\n", + " print(\"Path where test images are prepared for testing\")\n", + " print(path2im)\n", + " print(\"-------------------------------------------------------------\")\n", "\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = io.imread(Source_QC_folder+\"/\"+random_choice)\n", + " # Here we count how many images are in our folder to be predicted and we had a few\n", + " Nb_files_Data_folder = len(os.listdir(os.path.join(path2im, \"A\", \"test\")))\n", + " #Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) + 
10\n", "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", + " # List images in Source_QC_folder\n", "\n", - "Image_min_dim = int(min(Image_Y, Image_X))\n", + " #Here we copy and normalise the data\n", + " if prepare_testdata:\n", + " normalise_data(Source_QC_folder, Target_QC_folder, Normalisation_QC_source, Normalisation_QC_target, path2im)\n" + ], + "metadata": { + "cellView": "form", + "id": "eHUa1opwWi9G" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "egpLYnYp1w3L", + "pycharm": { + "name": "#%%\n" + }, + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@markdown ##Run Quality Control\n", + "if not avoid_evaluating_again:\n", + " # Here we count how many images are in our folder to be predicted and we had a few\n", + " Nb_files_Data_folder = len(os.listdir(os.path.join(path2im, \"A\", \"test\")))\n", + " #Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) + 10\n", + "\n", + " # List images in Source_QC_folder\n", + " #Here we create a merged folder containing only imageA\n", + " # TODO: check if we can remove this chdir\n", + " os.chdir(pix2pix_code_dir)\n", + " imageA_folder = os.path.join(path2im, \"A\")\n", + " imageB_folder = os.path.join(path2im, \"B\")\n", + " imageAB_folder = os.path.join(path2im, \"AB\")\n", + " !python3 pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", "\n", "if not patch_size_QC % 256 == 0:\n", - " patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n", - " print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n", + " patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n", + " print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n", "\n", "if patch_size_QC < 256:\n", - " patch_size_QC = 256\n", - "\n", - "Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n", + " patch_size_QC = 256\n", "\n", - "print(Nb_Checkpoint)\n", + "# Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n", + "Nb_Checkpoint = [int(s.split(\"_net_\")[0]) for s in os.listdir(full_QC_model_path) if s.__contains__('G.pth') and not s.__contains__('latest')]\n", + "Nb_Checkpoint.sort()\n", + "Nb_Checkpoint.append(\"latest\")\n", "\n", "## Initiate lists\n", - "\n", "Checkpoint_list = []\n", "Average_ssim_score_list = []\n", "Average_lpips_score_list = []\n", "\n", - "for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\n", - " checkpoints = j*5\n", - "\n", - " if checkpoints == Nb_Checkpoint*5:\n", - " checkpoints = \"latest\"\n", - "\n", - " print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n", - "\n", - " Checkpoint_list.append(checkpoints)\n", - "\n", - " # Create a quality control/Prediction Folder\n", + "for j in range(0, len(Nb_Checkpoint) + 1, QC_freq):\n", + " if j >= len(Nb_Checkpoint):\n", + " checkpoints = \"latest\"\n", + " else:\n", + " checkpoints = Nb_Checkpoint[j] # j*QC_freq\n", "\n", - " QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n", + " #if checkpoints == len(Nb_Checkpoint): # Nb_Checkpoint*QC_freq:\n", + " # checkpoints = \"latest\"\n", "\n", - " if os.path.exists(QC_prediction_results):\n", - " shutil.rmtree(QC_prediction_results)\n", + " print(\"The 
checkpoint currently analysed is =\"+str(checkpoints))\n", "\n", - " os.makedirs(QC_prediction_results)\n", + " Checkpoint_list.append(checkpoints)\n", "\n", - "#---------------------------- Predictions are performed here ----------------------\n", - " os.chdir(\"/content\")\n", - " !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", - "#-----------------------------------------------------------------------------------\n", + " # Create a quality control/Prediction Folder\n", + " QC_prediction_results = os.path.join(QC_model_path, QC_model_name, \"Quality Control\", str(checkpoints))\n", + " if not avoid_evaluating_again:\n", + " if os.path.exists(QC_prediction_results):\n", + " shutil.rmtree(QC_prediction_results)\n", + " os.makedirs(QC_prediction_results)\n", "\n", - "#Here we need to move the data again and remove all the unnecessary folders\n", + " #---------------------------- Predictions are performed here ----------------------\n", + " !python3 pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch \"$checkpoints\" --no_dropout --preprocess scale_width --load_size \"$patch_size_QC\" --crop_size \"$patch_size_QC\" --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test \"$Nb_files_Data_folder\" --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", + " #-----------------------------------------------------------------------------------\n", "\n", - " Checkpoint_name = \"test_\"+str(checkpoints)\n", + " #Here we need to move the data again and remove all the unnecessary folders\n", "\n", - " QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n", + " Checkpoint_name = \"test_\"+str(checkpoints)\n", + " QC_results_images = os.path.join(QC_prediction_results, QC_model_name, Checkpoint_name, \"images\")\n", + " QC_results_images_files = os.listdir(QC_results_images)\n", "\n", - " QC_results_images_files = os.listdir(QC_results_images)\n", + " for f in QC_results_images_files:\n", + " shutil.copyfile(os.path.join(QC_results_images, f), os.path.join(QC_prediction_results, f))\n", "\n", - " for f in QC_results_images_files: \n", - " shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n", - "\n", - " os.chdir(\"/content\") \n", - "\n", - " #Here we clean up the extra files\n", - " shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n", + " #Here we clean up the extra files\n", + " shutil.rmtree(os.path.join(QC_prediction_results, QC_model_name))\n", "\n", " #-------------------------------- QC for RGB ------------------------------------\n", - " if number_channels == \"3\":\n", - "# List images in Source_QC_folder\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - " random_choice = random.choice(os.listdir(Source_QC_folder))\n", - " x = io.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n", - "\n", - "# Open and create the csv file that will contain all the 
QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\" , \"Prediction v. GT lpips\", \"Input v. GT lpips\"])\n", - " \n", - " \n", - " # Initiate list\n", - " ssim_score_list = []\n", - " lpips_score_list = [] \n", - "\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " shortname_no_PNG = i[:-4]\n", - " \n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " \n", - " test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n", - " \n", - " \n", - " # -------------------------------- Prediction --------------------------------\n", - " \n", - " test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n", - " \n", - " #--------------------------- Here we normalise using histograms matching--------------------------------\n", - " test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n", - " test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n", - " \n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", + " if number_channels == \"3\":\n", "\n", - " # -------------------------------- Pearson correlation coefficient --------------------------------\n", + " Average_SSIM_checkpoint, Average_lpips_checkpoint = QC_RGB(Source_QC_folder, QC_prediction_results)\n", "\n", - "\n", - "\n", - "\n", - "\n", - " # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n", - " if Do_lpips_analysis:\n", - "\n", - " lpips_GTvsPrediction = perceptual_diff(test_GT, test_prediction, 'alex', True)\n", - " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", - " 
lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - "\n", - " lpips_GTvsSource = perceptual_diff(test_GT, test_source, 'alex', True)\n", - " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n", - "\n", - "\n", - " #lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image)\n", - "\n", - " #lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image)\n", - " else:\n", - " lpips_GTvsPrediction_score = 0\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource_score = 0\n", - "\n", - "\n", - " \n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource), str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - " Average_lpips_checkpoint = Average(lpips_score_list)\n", " Average_lpips_score_list.append(Average_lpips_checkpoint)\n", - "\n", "#------------------------------------------- QC for Grayscale ----------------------------------------------\n", - "\n", - " if number_channels == \"1\":\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - " def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - "\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - " def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - " def norm_minmse(gt, x, normalize_gt=True):\n", - " \n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\", \"Prediction v. GT lpips\", \"Input v. GT lpips\"]) \n", - "\n", - " # Initialize the lists\n", - " ssim_score_list = []\n", - " Pearson_correlation_coefficient_list = []\n", - " lpips_score_list = []\n", - " \n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " shortname_no_PNG = i[:-4]\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\")) \n", - " test_GT = test_GT_raw[:,:,2]\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\")) \n", - " test_source = test_source_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n", - " \n", - " test_prediction = test_prediction_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " \n", - " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n", - " 
io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n", - " img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - "\n", - " \n", - " # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n", - " if Do_lpips_analysis:\n", - " lpips_GTvsPrediction = perceptual_diff(test_GT_raw, test_prediction_raw, 'alex', True)\n", - " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource = perceptual_diff(test_GT_raw, test_source_raw, 'alex', True)\n", - " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n", - "\n", - "\n", - " lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image_8bit)\n", - "\n", - " lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image_8bit)\n", - " else:\n", - " lpips_GTvsPrediction_score = 0\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource_score = 0\n", - "\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource),str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", - "\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", + " if number_channels == \"1\":\n", + " Average_SSIM_checkpoint, Average_lpips_checkpoint = QC_singlechannel(Source_QC_folder, QC_prediction_results)\n", " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - " Average_lpips_checkpoint = Average(lpips_score_list)\n", " Average_lpips_score_list.append(Average_lpips_checkpoint)\n", "\n", - "\n", - "# All data is now processed saved\n", - " \n", - "\n", + " # All data is now processed saved\n", "# -------------------------------- Display --------------------------------\n", "\n", "# Display the IoV vs Checkpoint plot\n", @@ -2223,359 +2139,342 @@ "plt.ylabel('SSIM')\n", 
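As a complement to the plots, the averaged lists built above can also be used to pick a checkpoint programmatically. A minimal sketch, assuming `Checkpoint_list`, `Average_ssim_score_list` and `Average_lpips_score_list` are populated and index-aligned:

```python
import numpy as np

# Highest average mSSIM across the evaluated checkpoints
best_by_ssim = Checkpoint_list[int(np.argmax(Average_ssim_score_list))]
print("Best checkpoint by mSSIM:", best_by_ssim)

# Lowest average perceptual distance; only meaningful when Do_lpips_analysis is True,
# since the lpips list is filled with zeros otherwise
if Do_lpips_analysis:
    best_by_lpips = Checkpoint_list[int(np.argmin(Average_lpips_score_list))]
    print("Best checkpoint by lpips:", best_by_lpips)
```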
"plt.xlabel('Checkpoints')\n", "plt.legend()\n", - "plt.savefig(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n", + "plt.savefig(os.path.join(full_QC_model_path,'Quality Control','SSIMvsCheckpoint_data.png'),bbox_inches='tight',pad_inches=0)\n", "plt.show()\n", "\n", "\n", "# -------------------------------- Display --------------------------------\n", "\n", "if Do_lpips_analysis:\n", - " # Display the lpips vs Checkpoint plot\n", - " plt.figure(figsize=(20,5))\n", - " plt.plot(Checkpoint_list, Average_lpips_score_list, label=\"lpips\")\n", - " plt.title('Checkpoints vs. lpips')\n", - " plt.ylabel('lpips')\n", - " plt.xlabel('Checkpoints')\n", - " plt.legend()\n", - " plt.savefig(full_QC_model_path+'/Quality Control/lpipsvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n", - " plt.show()\n", - "\n", - "\n", + " # Display the lpips vs Checkpoint plot\n", + " plt.figure(figsize=(20,5))\n", + " plt.plot(Checkpoint_list, Average_lpips_score_list, label=\"lpips\")\n", + " plt.title('Checkpoints vs. lpips')\n", + " plt.ylabel('lpips')\n", + " plt.xlabel('Checkpoints')\n", + " plt.legend()\n", + " plt.savefig(os.path.join(full_QC_model_path,'Quality Control','lpipsvsCheckpoint_data.png'),bbox_inches='tight',pad_inches=0)\n", + " plt.show()\n", "\n", "# -------------------------------- Display RGB --------------------------------\n", "\n", + "\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", - "\n", - "\n", "if number_channels == \"3\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n", - " lpips_GTvsSource = df2.loc[file, \"Input v. 
GT lpips\"]\n", - "\n", - "#Setting up colours\n", - " cmap = None\n", - "\n", - " plt.figure(figsize=(15,15))\n", - "\n", - "# Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n", - " \n", - " plt.imshow(img_GT, cmap = cmap)\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n", - " plt.imshow(img_Source, cmap = cmap)\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - "\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", - "\n", - " plt.imshow(img_Prediction, cmap = cmap)\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "\n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#lpips Error between GT and source\n", - "\n", - " if Do_lpips_analysis:\n", - " plt.subplot(3,3,8)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Input',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "\n", - " #lpips Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", - "\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", + " @interact\n", + " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", + "\n", + " random_choice_shortname_no_PNG = file[:-4]\n", + " df1 = pd.read_csv(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints),\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\"), header=0)\n", + " df2 = df1.set_index(\"image #\", drop = False)\n", + " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", + " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", + " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n", + " lpips_GTvsSource = df2.loc[file, \"Input v. 
GT lpips\"]\n", + "\n", + " #Setting up colours\n", + " cmap = None\n", + " plt.figure(figsize=(15,15))\n", + "\n", + " # Target (Ground-truth)\n", + " plt.subplot(3,3,1)\n", + " plt.axis('off')\n", + " img_GT = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n", + " plt.imshow(img_GT, cmap = cmap)\n", + " plt.title('Target',fontsize=15)\n", + "\n", + " # Source\n", + " plt.subplot(3,3,2)\n", + " plt.axis('off')\n", + " img_Source = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n", + " plt.imshow(img_Source, cmap = cmap)\n", + " plt.title('Source',fontsize=15)\n", + "\n", + " # Prediction\n", + " plt.subplot(3,3,3)\n", + " plt.axis('off')\n", + " img_Prediction = io.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", + " plt.imshow(img_Prediction, cmap = cmap)\n", + " plt.title('Prediction',fontsize=15)\n", + "\n", + " # SSIM between GT and Source\n", + " plt.subplot(3,3,5)\n", + " # plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_SSIM_GTvsSource = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", + " # plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", + " plt.title('Target vs. Source',fontsize=15)\n", + " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", + " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", + "\n", + " # SSIM between GT and Prediction\n", + " plt.subplot(3,3,6)\n", + " # plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", + " # plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", + " plt.title('Target vs. 
Prediction',fontsize=15)\n", + " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", + "\n", + " # lpips Error between GT and source\n", + "\n", + " if Do_lpips_analysis:\n", + " plt.subplot(3,3,8)\n", + " # plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_lpips_GTvsSource = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", + " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. Source',fontsize=15)\n", + " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", + " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", + "\n", + "\n", + " # lpips Error between GT and Prediction\n", + " plt.subplot(3,3,9)\n", + " # plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_lpips_GTvsPrediction = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", + " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. Prediction',fontsize=15)\n", + " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", + "\n", + " plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'QC_example_data.png'),bbox_inches='tight',pad_inches=0)\n", "\n", "# -------------------------------- Display Grayscale --------------------------------\n", "\n", "if number_channels == \"1\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - "\n", - " NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n", - " NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n", - " PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n", - " PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n", - " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT lpips\"]\n", - " lpips_GTvsSource = df2.loc[file, \"Input v. GT lpips\"]\n", - "\n", - " plt.figure(figsize=(20,20))\n", - " # Currently only displays the last computed set, from memory\n", - " # Target (Ground-truth)\n", - " plt.subplot(4,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n", - "\n", - " plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(4,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n", - " plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(4,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", - " plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(4,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " \n", - " plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(4,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " \n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - " \n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - " plt.subplot(4,3,8)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n", - "\n", - " imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - " plt.subplot(4,3,9)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n", - "\n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#lpips Error between GT and source\n", - "\n", - " if Do_lpips_analysis:\n", - " plt.subplot(4,3,11)\n", - "\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_lpips_GTvsSource = img_lpips_GTvsSource / 255\n", - "\n", - " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Input',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#lpips Error between GT and Prediction\n", - " plt.subplot(4,3,12)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_lpips_GTvsPrediction = img_lpips_GTvsPrediction / 255\n", - "\n", - " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", - "\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", + " @interact\n", + " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", + " random_choice_shortname_no_PNG = file[:-4]\n", + "\n", + " df1 = pd.read_csv(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\"), header=0)\n", + " df2 = df1.set_index(\"image #\", drop = False)\n", + " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", + " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", + "\n", + " NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n", + " NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n", + " PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n", + " PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n", + " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n", + " lpips_GTvsSource = df2.loc[file, \"Input v. 
GT lpips\"]\n", + "\n", + " plt.figure(figsize=(20,20))\n", + " # Currently only displays the last computed set, from memory\n", + " # Target (Ground-truth)\n", + " plt.subplot(4,3,1)\n", + " plt.axis('off')\n", + " img_GT = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n", + "\n", + " plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n", + " plt.title('Target',fontsize=15)\n", + "\n", + " # Source\n", + " plt.subplot(4,3,2)\n", + " plt.axis('off')\n", + " img_Source = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n", + " plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n", + " plt.title('Source',fontsize=15)\n", + "\n", + " #Prediction\n", + " plt.subplot(4,3,3)\n", + " plt.axis('off')\n", + " img_Prediction = io.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", + " plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n", + " plt.title('Prediction',fontsize=15)\n", + "\n", + " #Setting up colours\n", + " cmap = plt.cm.CMRmap\n", + "\n", + " #SSIM between GT and Source\n", + " plt.subplot(4,3,5)\n", + " #plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_SSIM_GTvsSource = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + " img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n", + " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", + "\n", + " plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", + " plt.title('Target vs. Source',fontsize=15)\n", + " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", + " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", + "\n", + " #SSIM between GT and Prediction\n", + " plt.subplot(4,3,6)\n", + " #plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + " img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n", + " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", + "\n", + " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", + " plt.title('Target vs. 
Prediction',fontsize=15)\n", + " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", + "\n", + " #Root Squared Error between GT and Source\n", + " plt.subplot(4,3,8)\n", + " #plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_RSE_GTvsSource = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + " img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n", + "\n", + " imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", + " plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. Source',fontsize=15)\n", + " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", + " #plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", + " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", + "\n", + " #Root Squared Error between GT and Prediction\n", + " plt.subplot(4,3,9)\n", + " #plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_RSE_GTvsPrediction = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n", + "\n", + " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", + " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. Prediction',fontsize=15)\n", + " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", + "\n", + " #lpips Error between GT and source\n", + "\n", + " if Do_lpips_analysis:\n", + " plt.subplot(4,3,11)\n", + "\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_lpips_GTvsSource = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " img_lpips_GTvsSource = img_lpips_GTvsSource / 255\n", + "\n", + " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", + " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. 
Source',fontsize=15)\n", + " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", + " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", + "\n", + " #lpips Error between GT and Prediction\n", + " plt.subplot(4,3,12)\n", + " #plt.axis('off')\n", + " plt.tick_params(\n", + " axis='both', # changes apply to the x-axis and y-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " left=False, # ticks along the left edge are off\n", + " right=False, # ticks along the right edge are off\n", + " labelbottom=False,\n", + " labelleft=False)\n", + "\n", + " img_lpips_GTvsPrediction = imageio.imread(os.path.join(full_QC_model_path,'Quality Control', str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", + "\n", + " img_lpips_GTvsPrediction = img_lpips_GTvsPrediction / 255\n", + "\n", + " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", + " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", + " plt.title('Target vs. Prediction',fontsize=15)\n", + " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", + "\n", + " plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'QC_example_data.png'),bbox_inches='tight',pad_inches=0)\n", "\n", "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()\n" + "qc_pdf_export()" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "GTN0p7BY1w3O", "pycharm": { "name": "#%% md\n" @@ -2590,10 +2489,8 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { - "collapsed": false, "id": "ju5OiN0h1w3P", "pycharm": { "name": "#%% md\n" @@ -2603,7 +2500,7 @@ "## **6.1. Generate prediction(s) from unseen dataset**\n", "---\n", "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n", + "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** as PNG images, under the model name, in a subfolder called `results_<checkpoint>`.\n",
 "\n",
 "**`Data_folder`:** This folder should contain the images that you want to process with your trained network.\n",
 "\n",
@@ -2626,7 +2523,6 @@
 "source": [
 "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n",
 "import glob\n",
- "import os.path\n",
 "\n",
 "latest = \"latest\"\n",
 "\n",
@@ -2645,7 +2541,7 @@
 "\n",
 "#@markdown ###Image normalisation:\n",
 "\n",
- "Normalisation_prediction_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
+ "Normalisation_prediction_source = \"Contrast stretching\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n",
 "\n",
 "# model name and path\n",
 "#@markdown ###Do you want to use the current trained model?\n",
@@ -2668,10 +2564,15 @@
 "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n",
 "\n",
 "#here we check if we use the newly trained network or not\n",
- "if (Use_the_current_trained_model): \n",
- " print(\"Using current trained network\")\n",
- " Prediction_model_name = model_name\n",
- " Prediction_model_path = model_path\n",
+ "if (Use_the_current_trained_model):\n",
+ " try:\n",
+ " Prediction_model_name = model_name\n",
+ " Prediction_model_path = model_path\n",
+ " print(\"Using current trained network\")\n",
+ " except NameError:\n",
+ " print(\"Using current tested network in the QC\")\n",
+ " Prediction_model_name = QC_model_name\n",
+ " Prediction_model_path = QC_model_path\n",
 "\n",
 "if not patch_size % 256 == 0:\n",
 " patch_size = ((int(patch_size / 256)) * 256)\n",
@@ -2701,7 +2602,7 @@
 " if not checkpoint % 5 == 0:\n",
 " checkpoint = ((int(checkpoint / 5)-1) * 5)\n",
 " print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint used is now:\",checkpoint)\n",
- " \n",
+ "\n",
 " if checkpoint == Nb_Checkpoint*5:\n",
 " checkpoint = \"latest\"\n",
 "\n",
@@ -2710,36 +2611,35 @@
 "\n",
 "# Here we need to move the data to be analysed so that pix2pix can find them\n",
 "\n",
- "Saving_path_prediction= \"/content/\"+Prediction_model_name\n",
+ "Saving_path_prediction= os.path.join(pix2pix_working_directory, Prediction_model_name)\n",
 "\n",
 "if os.path.exists(Saving_path_prediction):\n",
 " shutil.rmtree(Saving_path_prediction)\n",
 "os.makedirs(Saving_path_prediction)\n",
 "\n",
- "imageA_folder = Saving_path_prediction+\"/A\"\n",
+ "imageA_folder = os.path.join(Saving_path_prediction, \"A\")\n",
 "os.makedirs(imageA_folder)\n",
 "\n",
- "imageB_folder = Saving_path_prediction+\"/B\"\n",
+ "imageB_folder = os.path.join(Saving_path_prediction, \"B\")\n",
 "os.makedirs(imageB_folder)\n",
 "\n",
- "imageAB_folder = Saving_path_prediction+\"/AB\"\n",
+ "imageAB_folder = os.path.join(Saving_path_prediction, \"AB\")\n",
 "os.makedirs(imageAB_folder)\n",
 "\n",
- "testAB_Folder = Saving_path_prediction+\"/AB/test\"\n",
+ "testAB_Folder = os.path.join(imageAB_folder, \"test\")\n",
 "os.makedirs(testAB_Folder)\n",
 "\n",
- "testA_Folder = Saving_path_prediction+\"/A/test\"\n",
+ "testA_Folder = os.path.join(imageA_folder, \"test\")\n",
 "os.makedirs(testA_Folder)\n",
- " \n",
- "testB_Folder = Saving_path_prediction+\"/B/test\"\n",
+ "\n",
+ "testB_Folder = os.path.join(imageB_folder, \"test\")\n",
 "os.makedirs(testB_Folder)\n",
 "\n",
 "#Here we copy and normalise the data\n",
 "\n",
 "if Normalisation_prediction_source == \"Contrast stretching\":\n",
- " \n",
- " for filename in os.listdir(Data_folder):\n",
 "\n",
+ " for filename in os.listdir(Data_folder):\n",
 " img = io.imread(os.path.join(Data_folder,filename)).astype(np.float32)\n",
 " short_name = os.path.splitext(filename)\n",
 "\n",
@@ -2748,9 +2648,9 @@
 "\n",
 " img = 255 * img # Now scale by 255\n",
 " img = img.astype(np.uint8)\n",
- " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n",
- " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n",
- " \n",
+ " cv2.imwrite(os.path.join(testA_Folder, short_name[0]+\".png\"), img)\n",
+ " cv2.imwrite(os.path.join(testB_Folder, short_name[0]+\".png\"), img)\n",
+ "\n",
 "if Normalisation_prediction_source == \"Adaptive Equalization\":\n",
 "\n",
 " for filename in os.listdir(Data_folder):\n",
@@ -2763,24 +2663,25 @@
 " img = 255 * img # Now scale by 255\n",
 " img = img.astype(np.uint8)\n",
 "\n",
- " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n",
- " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n",
+ " cv2.imwrite(os.path.join(testA_Folder, short_name[0]+\".png\"), img)\n",
+ " cv2.imwrite(os.path.join(testB_Folder, short_name[0]+\".png\"), img)\n",
 "\n",
 "if Normalisation_prediction_source == \"None\":\n",
 " for files in os.listdir(Data_folder):\n",
- " shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n",
- " shutil.copyfile(Data_folder+\"/\"+files, testB_Folder+\"/\"+files)\n",
- " \n",
+ " shutil.copyfile(os.path.join(Data_folder, files), os.path.join(testA_Folder, files))\n",
+ " shutil.copyfile(os.path.join(Data_folder, files), os.path.join(testB_Folder, files))\n",
+ "\n",
 "# Here we create a merged A / A image for the prediction\n",
- "os.chdir(\"/content\")\n",
- "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n",
+ "# TODO: check if we can remove this chdir\n",
+ "os.chdir(pix2pix_code_dir)\n",
+ "!python3 pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n",
 "\n",
 "# Here we count how many images are in our folder to be predicted and we add a few\n",
 "Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n",
 "\n",
- "# This will find the image dimension of a randomly choosen image in Data_folder \n",
+ "# This will find the image dimension of a randomly chosen image in Data_folder\n",
 "random_choice = random.choice(os.listdir(Data_folder))\n",
- "x = imageio.imread(Data_folder+\"/\"+random_choice)\n",
+ "x = imageio.imread(os.path.join(Data_folder, random_choice))\n",
 "\n",
 "#Find image XY dimension\n",
 "Image_Y = x.shape[0]\n",
@@ -2812,7 +2713,7 @@
 " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n",
 " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n",
 " #('--no_dropout', action='store_true', help='no dropout for the generator')\n",
- " \n",
+ "\n",
 "# dataset parameters\n",
 " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n", " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", @@ -2825,13 +2726,13 @@ " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", + "\n", "# additional parameters\n", " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", " #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", + "\n", "\n", " #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n", " #('--results_dir', type=str, default='./results/', help='saves results here.')\n", @@ -2842,7 +2743,7 @@ " #('--eval', action='store_true', help='use eval mode during test time.')\n", " #('--num_test', type=int, default=50, help='how many test images to run')\n", " # rewrite devalue values\n", - " \n", + "\n", "# To avoid cropping, the load_size should be the same as crop_size\n", " #parser.set_defaults(load_size=parser.get_default('crop_size'))\n", "\n", @@ -2850,23 +2751,23 @@ "\n", "\n", "#---------------------------- Predictions are performed here ----------------------\n", + "# TODO: check if we can remove this chdir\n", + "os.chdir(pix2pix_code_dir)\n", "\n", - "os.chdir(\"/content\")\n", - "\n", - "!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $patch_size --crop_size $patch_size --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", + "!python3 pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $patch_size --crop_size $patch_size --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint --input_nc \"$nc\" --output_nc \"$nc\" --dataset_mode \"aligned\"\n", "\n", "#-----------------------------------------------------------------------------------\n", "\n", "\n", - "Checkpoint_name = \"test_\"+str(checkpoint)\n", + "Checkpoint_name = \"results_\"+str(checkpoint)\n", "\n", "\n", - "Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n", + "Prediction_results_folder = os.path.join(Result_folder, Prediction_model_name, Checkpoint_name, \"images\")\n", "\n", "Prediction_results_images = os.listdir(Prediction_results_folder)\n", "\n", - "for f in Prediction_results_images: \n", - " if (f.endswith(\"_real_B.png\")): \n", + "for f in Prediction_results_images:\n", + " if (f.endswith(\"_real_B.png\")):\n", " os.remove(Prediction_results_folder+\"/\"+f)\n", "\n", "\n", @@ 
-2875,7 +2776,6 @@
 ]
 },
 {
- "attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "Pdnb77E15zLE"
@@ -2899,15 +2799,11 @@
 "import os\n",
 "# This will display a randomly chosen dataset input and predicted output\n",
 "random_choice = random.choice(os.listdir(Data_folder))\n",
- "\n",
- "\n",
 "random_choice_no_extension = os.path.splitext(random_choice)\n",
 "\n",
- "\n",
- "x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n",
- "\n",
- "\n",
- "y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n",
+ "results_path_test = os.path.join(Result_folder, Prediction_model_name, Checkpoint_name, \"images\")\n",
+ "x = imageio.imread(os.path.join(results_path_test, random_choice_no_extension[0]+\"_real_A.png\"))\n",
+ "y = imageio.imread(os.path.join(results_path_test, random_choice_no_extension[0]+\"_fake_B.png\"))\n",
 "\n",
 "f=plt.figure(figsize=(16,8))\n",
 "plt.subplot(1,2,1)\n",
@@ -2922,7 +2818,6 @@
 ]
 },
 {
- "attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "hvkd66PldsXB"
@@ -2935,7 +2830,6 @@
 ]
 },
 {
- "attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "HD0yZaIhUhth"
@@ -2943,7 +2837,12 @@
 "source": [
 "# **7. Version log**\n",
 "---\n",
- "**v1.15.1**: \n",
+ "**v1.16.1**:\n",
+ "* Removes references to '/content/'\n",
+ "* Uses predefined helper functions for handling the new images, making the code more readable\n",
+ "\n",
+ "\n",
+ "**v1.15.1**:\n",
 "* Many bug fixes by **Johanna Rahm**\n",
 "* Number of channels\n",
 "\n",
@@ -2963,12 +2862,11 @@
 "\n",
 "\n",
 "\n",
- "**v1.13**: \n",
+ "**v1.13**:\n",
 "\n",
 "\n",
 "\n",
 "* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n",
- "This version also now includes built-in version check and the version log that \n",
 "\n",
 "* This version also now includes built-in version check and the version log that you're reading now.\n",
 "\n",
@@ -2979,7 +2877,6 @@
 ]
 },
 {
- "attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "UvSlTaH14s3t"
@@ -2994,11 +2891,11 @@
 "accelerator": "GPU",
 "colab": {
 "machine_shape": "hm",
- "name": "pix2pix_ZeroCostDL4Mic.ipynb",
 "provenance": []
 },
 "kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
 "name": "python3"
 },
 "language_info": {
@@ -3011,9 +2909,9 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.6.7"
+ "version": "3.10.12"
 }
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
\ No newline at end of file
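
The `@interact` QC browser introduced in this diff pulls every per-image metric from a single CSV, keyed by filename through `set_index`/`.loc`. A minimal sketch of that lookup pattern, assuming only the column names visible in the cell; the CSV path argument is illustrative:

```python
# Minimal sketch of the per-image metric lookup used by show_results: the QC CSV
# is indexed by its "image #" column so each metric becomes a single .loc read.
# The column names follow the notebook cell; the csv_path is a placeholder.
import pandas as pd

def lookup_metrics(csv_path, image_name):
    df = pd.read_csv(csv_path, header=0).set_index("image #", drop=False)
    return {
        "mSSIM (prediction vs GT)": df.loc[image_name, "Prediction v. GT mSSIM"],
        "mSSIM (input vs GT)": df.loc[image_name, "Input v. GT mSSIM"],
        "NRMSE (prediction vs GT)": df.loc[image_name, "Prediction v. GT NRMSE"],
        "PSNR (prediction vs GT)": df.loc[image_name, "Prediction v. GT PSNR"],
        "lpips (prediction vs GT)": df.loc[image_name, "Prediction v. GT lpips"],
    }
```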
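The body of the "Contrast stretching" branch in the prediction cell falls between hunks, so the exact normalisation call is not visible here; the surrounding lines only show the float conversion, the scale by 255, and a `uint8` PNG written to both test folders. A hedged sketch of percentile-based contrast stretching consistent with those lines, assuming `skimage.exposure.rescale_intensity` and placeholder 2nd/98th percentiles:

```python
# Sketch of a percentile-based "Contrast stretching" pass, consistent with the
# visible lines (float32 in, scale by 255, uint8 PNG out). The percentiles and
# the use of exposure.rescale_intensity are assumptions: the notebook's actual
# normalisation call sits in lines elided from this diff.
import os
import numpy as np
import cv2
from skimage import io, exposure

def contrast_stretch_to_png(data_folder, test_a_folder, test_b_folder, p_low=2, p_high=98):
    for filename in os.listdir(data_folder):
        img = io.imread(os.path.join(data_folder, filename)).astype(np.float32)
        lo, hi = np.percentile(img, (p_low, p_high))
        img = exposure.rescale_intensity(img, in_range=(lo, hi), out_range=(0.0, 1.0))
        img = (255 * img).astype(np.uint8)  # now scale by 255, as in the cell
        png_name = os.path.splitext(filename)[0] + ".png"
        # pix2pix's aligned dataset expects the same source image on the A and B side here
        cv2.imwrite(os.path.join(test_a_folder, png_name), img)
        cv2.imwrite(os.path.join(test_b_folder, png_name), img)
```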
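The prediction cell names its output folder via `Checkpoint_name = "results_" + str(checkpoint)`, and the display cell should resolve the same path rather than re-derive it. A small sketch of that resolution; note that stock pytorch-CycleGAN-and-pix2pix `test.py` writes `test_<epoch>` unless the fork renames it, so treat the `results_` prefix as an assumption tied to this notebook:

```python
# Sketch of resolving the pix2pix output layout so every cell agrees on one
# folder name. The "results_" prefix follows Checkpoint_name in the cell above;
# if the test.py being run still writes "test_<epoch>", change only the prefix.
import os

def prediction_images_dir(result_folder, model_name, checkpoint, prefix="results_"):
    return os.path.join(result_folder, model_name, prefix + str(checkpoint), "images")

def drop_duplicate_targets(images_dir):
    # test.py saves <name>_real_A, <name>_real_B and <name>_fake_B; A and B were
    # fed the same source image here, so the _real_B copies are redundant.
    for f in os.listdir(images_dir):
        if f.endswith("_real_B.png"):
            os.remove(os.path.join(images_dir, f))
```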