From 0c4e7e315c7e30a5ef3b429087c949ad3e1ea817 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:40:39 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- notebooks/demo_denoise_mode.ipynb | 90 ++++++------- notebooks/example.ipynb | 168 +++++++++++++------------ notebooks/example_ms.ipynb | 18 +-- notebooks/stepbystep_Makinen2020.ipynb | 94 +++++++------- study/paramter_tuning/view_rst.ipynb | 60 ++++----- 5 files changed, 222 insertions(+), 208 deletions(-) diff --git a/notebooks/demo_denoise_mode.ipynb b/notebooks/demo_denoise_mode.ipynb index 9bc160b..c3ea061 100644 --- a/notebooks/demo_denoise_mode.ipynb +++ b/notebooks/demo_denoise_mode.ipynb @@ -43,15 +43,15 @@ "import matplotlib.pyplot as plt\n", "\n", "# Set the GPU device ID to 0 for this notebook session\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "this_dir = os.path.abspath(\"\")\n", - "data_dir = os.path.join(this_dir, '../tests/bm3dornl-data')\n", - "datafile = os.path.join(data_dir, 'tomostack_small.h5')\n", + "data_dir = os.path.join(this_dir, \"../tests/bm3dornl-data\")\n", + "datafile = os.path.join(data_dir, \"tomostack_small.h5\")\n", "\n", "# Load the data and select a noisy sinogram\n", - "with h5py.File(datafile, 'r') as f:\n", - " tomo_stack_noisy = f['noisy_tomostack'][:]\n", + "with h5py.File(datafile, \"r\") as f:\n", + " tomo_stack_noisy = f[\"noisy_tomostack\"][:]\n", "\n", "# Select a sinogram with low SNR\n", "sino_low_snr = tomo_stack_noisy[:, 10, :]\n", @@ -62,13 +62,13 @@ "# Plot the sinograms\n", "plt.figure(figsize=(15, 5))\n", "plt.subplot(121)\n", - "plt.imshow(sino_low_snr, cmap='gray')\n", + "plt.imshow(sino_low_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with low SNR')\n", + "plt.title(\"Sinogram with low SNR\")\n", "plt.subplot(122)\n", - "plt.imshow(sino_high_snr, cmap='gray')\n", + "plt.imshow(sino_high_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with high SNR')\n", + "plt.title(\"Sinogram with high SNR\")\n", "plt.show()" ] }, @@ -143,33 +143,33 @@ "# bottom: high_sino, diff\n", "plt.figure(figsize=(15, 10))\n", "plt.subplot(231)\n", - "plt.imshow(sino_low_snr, cmap='gray')\n", + "plt.imshow(sino_low_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with low SNR')\n", + "plt.title(\"Sinogram with low SNR\")\n", "plt.subplot(232)\n", - "plt.imshow(sino_low_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_low_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with low SNR')\n", + "plt.title(\"Denoised sinogram with low SNR\")\n", "plt.subplot(233)\n", "diff = sino_low_snr - sino_low_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.subplot(234)\n", - "plt.imshow(sino_high_snr, cmap='gray')\n", + "plt.imshow(sino_high_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with high SNR')\n", + "plt.title(\"Sinogram with high SNR\")\n", "plt.subplot(235)\n", - "plt.imshow(sino_high_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_high_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with high SNR')\n", + "plt.title(\"Denoised sinogram with high SNR\")\n", "plt.subplot(236)\n", "diff = sino_high_snr - sino_high_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.show()" ] }, @@ -207,33 +207,33 @@ "# bottom: high_sino, diff\n", "plt.figure(figsize=(15, 10))\n", "plt.subplot(231)\n", - "plt.imshow(sino_low_snr, cmap='gray')\n", + "plt.imshow(sino_low_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with low SNR')\n", + "plt.title(\"Sinogram with low SNR\")\n", "plt.subplot(232)\n", - "plt.imshow(sino_low_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_low_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with low SNR')\n", + "plt.title(\"Denoised sinogram with low SNR\")\n", "plt.subplot(233)\n", "diff = sino_low_snr - sino_low_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.subplot(234)\n", - "plt.imshow(sino_high_snr, cmap='gray')\n", + "plt.imshow(sino_high_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with high SNR')\n", + "plt.title(\"Sinogram with high SNR\")\n", "plt.subplot(235)\n", - "plt.imshow(sino_high_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_high_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with high SNR')\n", + "plt.title(\"Denoised sinogram with high SNR\")\n", "plt.subplot(236)\n", "diff = sino_high_snr - sino_high_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.show()" ] }, @@ -271,33 +271,33 @@ "# bottom: high_sino, diff\n", "plt.figure(figsize=(15, 10))\n", "plt.subplot(231)\n", - "plt.imshow(sino_low_snr, cmap='gray')\n", + "plt.imshow(sino_low_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with low SNR')\n", + "plt.title(\"Sinogram with low SNR\")\n", "plt.subplot(232)\n", - "plt.imshow(sino_low_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_low_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with low SNR')\n", + "plt.title(\"Denoised sinogram with low SNR\")\n", "plt.subplot(233)\n", "diff = sino_low_snr - sino_low_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.subplot(234)\n", - "plt.imshow(sino_high_snr, cmap='gray')\n", + "plt.imshow(sino_high_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with high SNR')\n", + "plt.title(\"Sinogram with high SNR\")\n", "plt.subplot(235)\n", - "plt.imshow(sino_high_snr_denoised, cmap='gray')\n", + "plt.imshow(sino_high_snr_denoised, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Denoised sinogram with high SNR')\n", + "plt.title(\"Denoised sinogram with high SNR\")\n", "plt.subplot(236)\n", "diff = sino_high_snr - sino_high_snr_denoised\n", "cval = np.quantile(np.abs(diff), 0.98)\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", "plt.colorbar()\n", - "plt.title('Difference')\n", + "plt.title(\"Difference\")\n", "plt.show()" ] }, diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index e567ff9..fc43a03 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -35,7 +35,7 @@ "import os\n", "\n", "# Set the GPU device ID to 0 for this notebook session\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { @@ -65,10 +65,12 @@ "outputs": [], "source": [ "# define image size\n", - "image_size = 512 # smaller size runs faster on local machine, large number means wider image\n", + "image_size = (\n", + " 512 # smaller size runs faster on local machine, large number means wider image\n", + ")\n", "scan_step = 0.5 # deg, smaller number means taller image\n", - "detector_gain_range=(0.97, 1.03) # variation along detector width\n", - "detector_gain_error=0.01 # variation along time/rotation" + "detector_gain_range = (0.97, 1.03) # variation along detector width\n", + "detector_gain_error = 0.01 # variation along time/rotation" ] }, { @@ -92,13 +94,13 @@ "shepp_logan_2d = shepp_logan_phantom(\n", " size=image_size,\n", " contrast_factor=8,\n", - " )\n", + ")\n", "\n", "# transform to sinogram\n", "sino_org, thetas_deg = generate_sinogram(\n", " input_img=shepp_logan_2d,\n", " scan_step=scan_step,\n", - " )\n", + ")\n", "\n", "# add detector gain error\n", "sino_noisy, detector_gain = simulate_detector_gain_error(\n", @@ -141,12 +143,12 @@ ], "source": [ "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", - "ax[0].imshow(shepp_logan_2d, cmap='gray')\n", - "ax[0].set_title('Original Shepp-Logan Phantom')\n", - "ax[1].imshow(sino_org, cmap='gray')\n", - "ax[1].set_title('Sinogram')\n", - "ax[2].imshow(sino_noisy, cmap='gray')\n", - "ax[2].set_title('Sinogram with Detector Gain Error')\n", + "ax[0].imshow(shepp_logan_2d, cmap=\"gray\")\n", + "ax[0].set_title(\"Original Shepp-Logan Phantom\")\n", + "ax[1].imshow(sino_org, cmap=\"gray\")\n", + "ax[1].set_title(\"Sinogram\")\n", + "ax[2].imshow(sino_noisy, cmap=\"gray\")\n", + "ax[2].set_title(\"Sinogram with Detector Gain Error\")\n", "\n", "print(shepp_logan_2d.shape, sino_org.shape, sino_noisy.shape)\n", "print(sino_noisy.min(), sino_noisy.max())" @@ -436,10 +438,10 @@ "\n", "sion_bm3d_attenuated = bm3dsr.extreme_streak_attenuation(\n", " data=sino_noisy,\n", - "# extreme_streak_iterations=3,\n", - "# extreme_detect_lambda=4.0,\n", - "# extreme_detect_size=9,\n", - "# extreme_replace_size=2,\n", + " # extreme_streak_iterations=3,\n", + " # extreme_detect_lambda=4.0,\n", + " # extreme_detect_size=9,\n", + " # extreme_replace_size=2,\n", ")" ] }, @@ -575,12 +577,12 @@ ], "source": [ "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", - "axs[0].imshow(sino_noisy, cmap='gray')\n", - "axs[0].set_title('Noisy sinogram')\n", - "axs[1].imshow(sion_bm3d_attenuated, cmap='gray')\n", - "axs[1].set_title('BM3D extreme streak attenuation')\n", - "axs[2].imshow(sino_bm3d, cmap='gray')\n", - "axs[2].set_title('BM3D denoised sinogram')\n", + "axs[0].imshow(sino_noisy, cmap=\"gray\")\n", + "axs[0].set_title(\"Noisy sinogram\")\n", + "axs[1].imshow(sion_bm3d_attenuated, cmap=\"gray\")\n", + "axs[1].set_title(\"BM3D extreme streak attenuation\")\n", + "axs[2].imshow(sino_bm3d, cmap=\"gray\")\n", + "axs[2].set_title(\"BM3D denoised sinogram\")\n", "plt.show()" ] }, @@ -747,7 +749,7 @@ "\n", "val_cap = np.absolute(sino_bm3d - sino_noisy).max()\n", "\n", - "print(f'val_cap: {val_cap}')\n", + "print(f\"val_cap: {val_cap}\")\n", "\n", "# top row\n", "# - CDF\n", @@ -761,51 +763,53 @@ "# cdf_bm3dornl\n", "cdf_bm3dornl_sorted, p_bm3dornl = compute_cdf(sino_bm3dornl)\n", "# plot\n", - "axs[0, 0].plot(cdf_org_sorted, p_org, label='Original', linestyle='--', linewidth=0.5)\n", - "axs[0, 0].plot(cdf_noisy_sorted, p_noisy, label='Noisy', linestyle='--', linewidth=0.5)\n", - "axs[0, 0].plot(cdf_bm3d_sorted, p_bm3d, label='BM3D', linestyle='--', linewidth=0.5)\n", - "axs[0, 0].plot(cdf_bm3dornl_sorted, p_bm3dornl, label='BM3D-ORNL', linestyle='--', linewidth=0.5)\n", - "axs[0, 0].set_title('CDF')\n", + "axs[0, 0].plot(cdf_org_sorted, p_org, label=\"Original\", linestyle=\"--\", linewidth=0.5)\n", + "axs[0, 0].plot(cdf_noisy_sorted, p_noisy, label=\"Noisy\", linestyle=\"--\", linewidth=0.5)\n", + "axs[0, 0].plot(cdf_bm3d_sorted, p_bm3d, label=\"BM3D\", linestyle=\"--\", linewidth=0.5)\n", + "axs[0, 0].plot(\n", + " cdf_bm3dornl_sorted, p_bm3dornl, label=\"BM3D-ORNL\", linestyle=\"--\", linewidth=0.5\n", + ")\n", + "axs[0, 0].set_title(\"CDF\")\n", "axs[0, 0].legend()\n", "# - BM3D\n", - "axs[0, 1].imshow(sino_bm3d, cmap='gray')\n", - "axs[0, 1].set_title('BM3D')\n", - "axs[0, 1].axis('off')\n", + "axs[0, 1].imshow(sino_bm3d, cmap=\"gray\")\n", + "axs[0, 1].set_title(\"BM3D\")\n", + "axs[0, 1].axis(\"off\")\n", "# - BM3D-ORNL\n", - "axs[0, 2].imshow(sino_bm3dornl, cmap='gray')\n", - "axs[0, 2].set_title('BM3D-ORNL')\n", - "axs[0, 2].axis('off')\n", + "axs[0, 2].imshow(sino_bm3dornl, cmap=\"gray\")\n", + "axs[0, 2].set_title(\"BM3D-ORNL\")\n", + "axs[0, 2].axis(\"off\")\n", "\n", "# middle row\n", "# - Original\n", - "axs[1, 0].imshow(sino_org, cmap='gray')\n", - "axs[1, 0].set_title('Original')\n", - "axs[1, 0].axis('off')\n", + "axs[1, 0].imshow(sino_org, cmap=\"gray\")\n", + "axs[1, 0].set_title(\"Original\")\n", + "axs[1, 0].axis(\"off\")\n", "# - BM3D - Original\n", - "axs[1, 1].imshow(sino_bm3d - sino_org, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[1, 1].set_title('BM3D - Original')\n", - "axs[1, 1].axis('off')\n", + "axs[1, 1].imshow(sino_bm3d - sino_org, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[1, 1].set_title(\"BM3D - Original\")\n", + "axs[1, 1].axis(\"off\")\n", "# - BM3D-ORNL - Original\n", - "axs[1, 2].imshow(sino_bm3dornl - sino_org, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[1, 2].set_title('BM3D-ORNL - Original')\n", - "axs[1, 2].axis('off')\n", + "axs[1, 2].imshow(sino_bm3dornl - sino_org, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[1, 2].set_title(\"BM3D-ORNL - Original\")\n", + "axs[1, 2].axis(\"off\")\n", "\n", "# bottom row\n", "# - Noisy\n", - "axs[2, 0].imshow(sino_noisy, cmap='gray')\n", - "axs[2, 0].set_title('Noisy')\n", - "axs[2, 0].axis('off')\n", + "axs[2, 0].imshow(sino_noisy, cmap=\"gray\")\n", + "axs[2, 0].set_title(\"Noisy\")\n", + "axs[2, 0].axis(\"off\")\n", "# - BM3D - Noisy\n", - "axs[2, 1].imshow(sino_bm3d - sino_noisy, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[2, 1].set_title('BM3D - Noisy')\n", - "axs[2, 1].axis('off')\n", + "axs[2, 1].imshow(sino_bm3d - sino_noisy, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[2, 1].set_title(\"BM3D - Noisy\")\n", + "axs[2, 1].axis(\"off\")\n", "# - BM3D-ORNL - Noisy\n", - "axs[2, 2].imshow(sino_bm3dornl - sino_noisy, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[2, 2].set_title('BM3D-ORNL - Noisy')\n", - "axs[2, 2].axis('off')\n", + "axs[2, 2].imshow(sino_bm3dornl - sino_noisy, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[2, 2].set_title(\"BM3D-ORNL - Noisy\")\n", + "axs[2, 2].axis(\"off\")\n", "\n", "fig.tight_layout()\n", - "fig.savefig('bm3d_vs_bm3dornl_sino.png')\n", + "fig.savefig(\"bm3d_vs_bm3dornl_sino.png\")\n", "\n", "plt.show()" ] @@ -845,8 +849,8 @@ "\n", "kwargs = {\n", " \"center\": sino_org.shape[1] / 2,\n", - " \"algorithm\": 'fbp',\n", - " \"filter_name\": 'shepp',\n", + " \"algorithm\": \"fbp\",\n", + " \"filter_name\": \"shepp\",\n", "}\n", "\n", "#\n", @@ -911,21 +915,21 @@ "\n", "# top row\n", "# - Original\n", - "axs[0, 0].imshow(recon_org, cmap='gray')\n", - "axs[0, 0].set_title('Original')\n", - "axs[0, 0].axis('off')\n", + "axs[0, 0].imshow(recon_org, cmap=\"gray\")\n", + "axs[0, 0].set_title(\"Original\")\n", + "axs[0, 0].axis(\"off\")\n", "# - Noisy\n", - "axs[0, 1].imshow(recon_noisy, cmap='gray')\n", - "axs[0, 1].set_title('Noisy')\n", - "axs[0, 1].axis('off')\n", + "axs[0, 1].imshow(recon_noisy, cmap=\"gray\")\n", + "axs[0, 1].set_title(\"Noisy\")\n", + "axs[0, 1].axis(\"off\")\n", "# - BM3D\n", - "axs[0, 2].imshow(recon_bm3d, cmap='gray')\n", - "axs[0, 2].set_title('BM3D')\n", - "axs[0, 2].axis('off')\n", + "axs[0, 2].imshow(recon_bm3d, cmap=\"gray\")\n", + "axs[0, 2].set_title(\"BM3D\")\n", + "axs[0, 2].axis(\"off\")\n", "# - BM3D-ORNL\n", - "axs[0, 3].imshow(recon_bm3dornl, cmap='gray')\n", - "axs[0, 3].set_title('BM3D-ORNL')\n", - "axs[0, 3].axis('off')\n", + "axs[0, 3].imshow(recon_bm3dornl, cmap=\"gray\")\n", + "axs[0, 3].set_title(\"BM3D-ORNL\")\n", + "axs[0, 3].axis(\"off\")\n", "\n", "# bottom row\n", "# - CDF\n", @@ -939,24 +943,26 @@ "# cdf_bm3dornl\n", "cdf_bm3dornl_sorted, p_bm3dornl = compute_cdf(recon_bm3dornl)\n", "# plot\n", - "axs[1, 0].plot(cdf_org_sorted, p_org, label='Original', linestyle='--', linewidth=0.5)\n", - "axs[1, 0].plot(cdf_noisy_sorted, p_noisy, label='Noisy', linestyle='--', linewidth=0.5)\n", - "axs[1, 0].plot(cdf_bm3d_sorted, p_bm3d, label='BM3D', linestyle='--', linewidth=0.5)\n", - "axs[1, 0].plot(cdf_bm3dornl_sorted, p_bm3dornl, label='BM3D-ORNL', linestyle='--', linewidth=0.5)\n", - "axs[1, 0].set_title('CDF')\n", + "axs[1, 0].plot(cdf_org_sorted, p_org, label=\"Original\", linestyle=\"--\", linewidth=0.5)\n", + "axs[1, 0].plot(cdf_noisy_sorted, p_noisy, label=\"Noisy\", linestyle=\"--\", linewidth=0.5)\n", + "axs[1, 0].plot(cdf_bm3d_sorted, p_bm3d, label=\"BM3D\", linestyle=\"--\", linewidth=0.5)\n", + "axs[1, 0].plot(\n", + " cdf_bm3dornl_sorted, p_bm3dornl, label=\"BM3D-ORNL\", linestyle=\"--\", linewidth=0.5\n", + ")\n", + "axs[1, 0].set_title(\"CDF\")\n", "axs[1, 0].legend()\n", "# - Noisy - Original\n", - "axs[1, 1].imshow(recon_noisy - recon_org, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[1, 1].set_title('Noisy - Original')\n", - "axs[1, 1].axis('off')\n", + "axs[1, 1].imshow(recon_noisy - recon_org, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[1, 1].set_title(\"Noisy - Original\")\n", + "axs[1, 1].axis(\"off\")\n", "# - BM3D - Original\n", - "axs[1, 2].imshow(recon_bm3d - recon_org, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[1, 2].set_title('BM3D - Original')\n", - "axs[1, 2].axis('off')\n", + "axs[1, 2].imshow(recon_bm3d - recon_org, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[1, 2].set_title(\"BM3D - Original\")\n", + "axs[1, 2].axis(\"off\")\n", "# - BM3D-ORNL - Original\n", - "axs[1, 3].imshow(recon_bm3dornl - recon_org, cmap='bwr', vmin=-val_cap, vmax=val_cap)\n", - "axs[1, 3].set_title('BM3D-ORNL - Original')\n", - "axs[1, 3].axis('off')\n" + "axs[1, 3].imshow(recon_bm3dornl - recon_org, cmap=\"bwr\", vmin=-val_cap, vmax=val_cap)\n", + "axs[1, 3].set_title(\"BM3D-ORNL - Original\")\n", + "axs[1, 3].axis(\"off\")" ] }, { diff --git a/notebooks/example_ms.ipynb b/notebooks/example_ms.ipynb index 049c52b..65479ef 100644 --- a/notebooks/example_ms.ipynb +++ b/notebooks/example_ms.ipynb @@ -23,15 +23,15 @@ "from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms\n", "\n", "# Set the GPU device ID to 0 for this notebook session\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "this_dir = os.path.abspath(\"\")\n", - "data_dir = os.path.join(this_dir, '../tests/bm3dornl-data')\n", - "datafile = os.path.join(data_dir, 'tomostack_small.h5')\n", + "data_dir = os.path.join(this_dir, \"../tests/bm3dornl-data\")\n", + "datafile = os.path.join(data_dir, \"tomostack_small.h5\")\n", "\n", "# Load the data and select a noisy sinogram\n", - "with h5py.File(datafile, 'r') as f:\n", - " tomo_stack_noisy = f['noisy_tomostack'][:]\n", + "with h5py.File(datafile, \"r\") as f:\n", + " tomo_stack_noisy = f[\"noisy_tomostack\"][:]\n", "\n", "# Select a sinogram with low SNR\n", "sino_low_snr = tomo_stack_noisy[:, 10, :]\n", @@ -42,13 +42,13 @@ "# Plot the sinograms\n", "plt.figure(figsize=(15, 5))\n", "plt.subplot(121)\n", - "plt.imshow(sino_low_snr, cmap='gray')\n", + "plt.imshow(sino_low_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with low SNR')\n", + "plt.title(\"Sinogram with low SNR\")\n", "plt.subplot(122)\n", - "plt.imshow(sino_high_snr, cmap='gray')\n", + "plt.imshow(sino_high_snr, cmap=\"gray\")\n", "plt.colorbar()\n", - "plt.title('Sinogram with high SNR')\n", + "plt.title(\"Sinogram with high SNR\")\n", "plt.show()" ] }, diff --git a/notebooks/stepbystep_Makinen2020.ipynb b/notebooks/stepbystep_Makinen2020.ipynb index fb12454..b46d133 100644 --- a/notebooks/stepbystep_Makinen2020.ipynb +++ b/notebooks/stepbystep_Makinen2020.ipynb @@ -37,12 +37,12 @@ "outputs": [], "source": [ "this_dir = os.path.abspath(\"\")\n", - "data_dir = os.path.join(this_dir, '../tests/bm3dornl-data')\n", - "datafile = os.path.join(data_dir, 'tomostack_small.h5')\n", + "data_dir = os.path.join(this_dir, \"../tests/bm3dornl-data\")\n", + "datafile = os.path.join(data_dir, \"tomostack_small.h5\")\n", "\n", - "with h5py.File(datafile, 'r') as f:\n", - " tomo_noisy = f['noisy_tomostack'][:]\n", - " tomo_bm3d_ref = f['clean_tomostack'][:]" + "with h5py.File(datafile, \"r\") as f:\n", + " tomo_noisy = f[\"noisy_tomostack\"][:]\n", + " tomo_bm3d_ref = f[\"clean_tomostack\"][:]" ] }, { @@ -91,10 +91,10 @@ "\n", "# visualize side by side\n", "fig, ax = plt.subplots(1, 2)\n", - "ax[0].imshow(sino_noisy, cmap='gray')\n", - "ax[0].set_title('Noisy sinogram')\n", - "ax[1].imshow(sino_bm3d_ref, cmap='gray')\n", - "ax[1].set_title('BM3D ref sinogram')\n", + "ax[0].imshow(sino_noisy, cmap=\"gray\")\n", + "ax[0].set_title(\"Noisy sinogram\")\n", + "ax[1].imshow(sino_bm3d_ref, cmap=\"gray\")\n", + "ax[1].set_title(\"BM3D ref sinogram\")\n", "plt.show()" ] }, @@ -197,7 +197,9 @@ " background_threshold=background_threshold,\n", ")\n", "# retrieve all patches (need the actual signals)\n", - "signal_patches = np.array([get_patch_numba(z, pos, patch_size) for pos in patch_positions])\n", + "signal_patches = np.array(\n", + " [get_patch_numba(z, pos, patch_size) for pos in patch_positions]\n", + ")\n", "# transform the patches to the frequency domain (use fft here for simplicity)\n", "signal_patches_fft = fft_transform(signal_patches)\n", "# estimate the noise variance for the transformed patch\n", @@ -237,12 +239,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(yhat_ht, cmap='gray')\n", - "plt.title('$\\hat{y}_{ht}$')\n", + "plt.imshow(yhat_ht, cmap=\"gray\")\n", + "plt.title(\"$\\hat{y}_{ht}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -287,12 +289,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(z_gft_ht, cmap='gray')\n", - "plt.title('$z^{GFT}_{ht}$')\n", + "plt.imshow(z_gft_ht, cmap=\"gray\")\n", + "plt.title(\"$z^{GFT}_{ht}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -315,7 +317,9 @@ "outputs": [], "source": [ "# retrieve patches from z_gft_ht\n", - "signal_patches = np.array([get_patch_numba(z_gft_ht, pos, patch_size) for pos in patch_positions])\n", + "signal_patches = np.array(\n", + " [get_patch_numba(z_gft_ht, pos, patch_size) for pos in patch_positions]\n", + ")\n", "# compute the transformed patches\n", "signal_patches_fft = fft_transform(signal_patches)\n", "# estimate the noise variance for the transformed patch\n", @@ -355,12 +359,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(yhat_ht_gft, cmap='gray')\n", - "plt.title('$\\hat{y}^{GFT}_{ht}$')\n", + "plt.imshow(yhat_ht_gft, cmap=\"gray\")\n", + "plt.title(\"$\\hat{y}^{GFT}_{ht}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -423,12 +427,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(yhat_wie, cmap='gray')\n", - "plt.title('$\\hat{y}_{wie}$')\n", + "plt.imshow(yhat_wie, cmap=\"gray\")\n", + "plt.title(\"$\\hat{y}_{wie}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -473,12 +477,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(z_gft_wie, cmap='gray')\n", - "plt.title('$z^{GFT}_{wie}$')\n", + "plt.imshow(z_gft_wie, cmap=\"gray\")\n", + "plt.title(\"$z^{GFT}_{wie}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -501,7 +505,9 @@ "outputs": [], "source": [ "# first thing first, we need to compute PSD for the new noisy sinogram\n", - "signal_patches = np.array([get_patch_numba(z_gft_wie, pos, patch_size) for pos in patch_positions])\n", + "signal_patches = np.array(\n", + " [get_patch_numba(z_gft_wie, pos, patch_size) for pos in patch_positions]\n", + ")\n", "signal_patches_fft = fft_transform(signal_patches)\n", "phi_fft_gft_wie = get_exact_noise_variance(signal_patches_fft)\n", "# commence wiener filtering\n", @@ -539,12 +545,12 @@ "# quick visualization\n", "plt.figure(figsize=(12, 5))\n", "plt.subplot(1, 2, 1)\n", - "plt.imshow(z, cmap='gray')\n", - "plt.title('z')\n", + "plt.imshow(z, cmap=\"gray\")\n", + "plt.title(\"z\")\n", "plt.colorbar()\n", "plt.subplot(1, 2, 2)\n", - "plt.imshow(yhat_final, cmap='gray')\n", - "plt.title('$\\hat{y}_{final}$')\n", + "plt.imshow(yhat_final, cmap=\"gray\")\n", + "plt.title(\"$\\hat{y}_{final}$\")\n", "plt.colorbar()\n", "plt.show()" ] @@ -563,7 +569,9 @@ "outputs": [], "source": [ "# rescale the final denoised sinogram to [0, 1]\n", - "yhat_final = (yhat_final - np.min(yhat_final)) / (np.max(yhat_final) - np.min(yhat_final))\n", + "yhat_final = (yhat_final - np.min(yhat_final)) / (\n", + " np.max(yhat_final) - np.min(yhat_final)\n", + ")\n", "# rescale the final denoised sinogram to the original dynamic range\n", "sino_denoised = yhat_final * (original_max - original_min) + original_min" ] @@ -588,18 +596,18 @@ "# visualize side by side\n", "plt.figure(figsize=(16, 5))\n", "plt.subplot(1, 3, 1)\n", - "plt.imshow(sino_noisy, cmap='gray')\n", - "plt.title('Noisy sinogram')\n", + "plt.imshow(sino_noisy, cmap=\"gray\")\n", + "plt.title(\"Noisy sinogram\")\n", "plt.colorbar()\n", "plt.subplot(1, 3, 2)\n", - "plt.imshow(sino_denoised, cmap='gray')\n", - "plt.title('bm3dornl sinogram')\n", + "plt.imshow(sino_denoised, cmap=\"gray\")\n", + "plt.title(\"bm3dornl sinogram\")\n", "plt.colorbar()\n", "plt.subplot(1, 3, 3)\n", "diff = sino_denoised - sino_noisy\n", "cval = np.max(np.abs(diff))\n", - "plt.imshow(diff, cmap='seismic', vmin=-cval, vmax=cval)\n", - "plt.title('Difference')\n", + "plt.imshow(diff, cmap=\"seismic\", vmin=-cval, vmax=cval)\n", + "plt.title(\"Difference\")\n", "plt.colorbar()\n", "plt.show()" ] diff --git a/study/paramter_tuning/view_rst.ipynb b/study/paramter_tuning/view_rst.ipynb index cac859e..8d4e938 100644 --- a/study/paramter_tuning/view_rst.ipynb +++ b/study/paramter_tuning/view_rst.ipynb @@ -36,17 +36,17 @@ "import os\n", "\n", "# Set the GPU device ID to 0 for this notebook session\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "import h5py\n", "\n", "this_dir = os.path.abspath(\"\")\n", - "data_dir = os.path.join(this_dir, '../../tests/bm3dornl-data')\n", - "datafile = os.path.join(data_dir, 'tomostack_small.h5')\n", + "data_dir = os.path.join(this_dir, \"../../tests/bm3dornl-data\")\n", + "datafile = os.path.join(data_dir, \"tomostack_small.h5\")\n", "\n", "# Load the data and select a noisy sinogram\n", - "with h5py.File(datafile, 'r') as f:\n", - " tomo_stack_noisy = f['noisy_tomostack'][:]\n", + "with h5py.File(datafile, \"r\") as f:\n", + " tomo_stack_noisy = f[\"noisy_tomostack\"][:]\n", " tomo_stack_clean = f[\"clean_tomostack\"][:]" ] }, @@ -279,9 +279,9 @@ "source": [ "import bm3d_streak_removal as bm3dsr\n", "from bm3dornl.bm3d import (\n", - " estimate_noise_free_sinogram,\n", - " bm3d_ring_artifact_removal,\n", - " bm3d_ring_artifact_removal_ms,\n", + " estimate_noise_free_sinogram,\n", + " bm3d_ring_artifact_removal,\n", + " bm3d_ring_artifact_removal_ms,\n", ")" ] }, @@ -291,7 +291,7 @@ "metadata": {}, "outputs": [], "source": [ - "def view_slice(idx: int, ms_iter: int=3):\n", + "def view_slice(idx: int, ms_iter: int = 3):\n", " # input\n", " sino_noisy = tomo_stack_noisy[:, idx, :]\n", " # estimate\n", @@ -334,33 +334,33 @@ " # row_0: input, sino_bm3d, sino_bm3d_ms, sino_bm3d_ms - input, sino_bm3d_ms - sino_bm3d\n", " # row_1: estimate, sino_bm3dornl, sino_bm3dornl_ms, sino_bm3dornl_ms - input, sino_bm3dornl_ms - sino_bm3dornl\n", " fig, axs = plt.subplots(2, 5, figsize=(20, 10))\n", - " axs[0, 0].imshow(sino_noisy, cmap='gray')\n", + " axs[0, 0].imshow(sino_noisy, cmap=\"gray\")\n", " axs[0, 0].set_title(\"Input\")\n", - " axs[0, 1].imshow(sino_bm3d, cmap='gray')\n", + " axs[0, 1].imshow(sino_bm3d, cmap=\"gray\")\n", " axs[0, 1].set_title(\"BM3D\")\n", - " axs[0, 2].imshow(sino_bm3d_ms, cmap='gray')\n", + " axs[0, 2].imshow(sino_bm3d_ms, cmap=\"gray\")\n", " axs[0, 2].set_title(\"BM3D MS\")\n", " diff = sino_bm3d_ms - sino_noisy\n", " cval = np.max(np.abs(diff))\n", - " axs[0, 3].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[0, 3].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[0, 3].set_title(f\"BM3D MS - Input: cval={cval}\")\n", " diff = sino_bm3d_ms - sino_bm3d\n", " cval = np.max(np.abs(diff))\n", - " axs[0, 4].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[0, 4].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[0, 4].set_title(f\"BM3D MS - BM3D: cval={cval}\")\n", - " axs[1, 0].imshow(sino_estimate, cmap='gray')\n", + " axs[1, 0].imshow(sino_estimate, cmap=\"gray\")\n", " axs[1, 0].set_title(\"Estimate\")\n", - " axs[1, 1].imshow(sino_bm3dornl, cmap='gray')\n", + " axs[1, 1].imshow(sino_bm3dornl, cmap=\"gray\")\n", " axs[1, 1].set_title(\"BM3D ORNL\")\n", - " axs[1, 2].imshow(sino_bm3dornl_ms, cmap='gray')\n", + " axs[1, 2].imshow(sino_bm3dornl_ms, cmap=\"gray\")\n", " axs[1, 2].set_title(\"BM3D ORNL MS\")\n", " diff = sino_bm3dornl_ms - sino_noisy\n", " cval = np.max(np.abs(diff))\n", - " axs[1, 3].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[1, 3].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[1, 3].set_title(f\"BM3D ORNL MS - Input: cval={cval}\")\n", " diff = sino_bm3dornl_ms - sino_bm3dornl\n", " cval = np.max(np.abs(diff))\n", - " axs[1, 4].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[1, 4].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[1, 4].set_title(f\"BM3D ORNL MS - BM3D ORNL: cval={cval}\")\n", " return fig, axs" ] @@ -821,7 +821,7 @@ "metadata": {}, "outputs": [], "source": [ - "def view_slice(idx: int, ms_iter: int=3):\n", + "def view_slice(idx: int, ms_iter: int = 3):\n", " # input\n", " sino_noisy = tomo_stack_noisy[:, idx, :]\n", " # estimate\n", @@ -864,33 +864,33 @@ " # row_0: input, sino_bm3d, sino_bm3d_ms, sino_bm3d_ms - input, sino_bm3d_ms - sino_bm3d\n", " # row_1: estimate, sino_bm3dornl, sino_bm3dornl_ms, sino_bm3dornl_ms - input, sino_bm3dornl_ms - sino_bm3dornl\n", " fig, axs = plt.subplots(2, 5, figsize=(20, 10))\n", - " axs[0, 0].imshow(sino_noisy, cmap='gray')\n", + " axs[0, 0].imshow(sino_noisy, cmap=\"gray\")\n", " axs[0, 0].set_title(\"Input\")\n", - " axs[0, 1].imshow(sino_bm3d, cmap='gray')\n", + " axs[0, 1].imshow(sino_bm3d, cmap=\"gray\")\n", " axs[0, 1].set_title(\"BM3D\")\n", - " axs[0, 2].imshow(sino_bm3d_ms, cmap='gray')\n", + " axs[0, 2].imshow(sino_bm3d_ms, cmap=\"gray\")\n", " axs[0, 2].set_title(\"BM3D MS\")\n", " diff = sino_bm3d_ms - sino_noisy\n", " cval = np.max(np.abs(diff))\n", - " axs[0, 3].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[0, 3].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[0, 3].set_title(f\"BM3D MS - Input: cval={cval}\")\n", " diff = sino_bm3d_ms - sino_bm3d\n", " cval = np.max(np.abs(diff))\n", - " axs[0, 4].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[0, 4].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[0, 4].set_title(f\"BM3D MS - BM3D: cval={cval}\")\n", - " axs[1, 0].imshow(sino_estimate, cmap='gray')\n", + " axs[1, 0].imshow(sino_estimate, cmap=\"gray\")\n", " axs[1, 0].set_title(\"Estimate\")\n", - " axs[1, 1].imshow(sino_bm3dornl, cmap='gray')\n", + " axs[1, 1].imshow(sino_bm3dornl, cmap=\"gray\")\n", " axs[1, 1].set_title(\"BM3D ORNL\")\n", - " axs[1, 2].imshow(sino_bm3dornl_ms, cmap='gray')\n", + " axs[1, 2].imshow(sino_bm3dornl_ms, cmap=\"gray\")\n", " axs[1, 2].set_title(\"BM3D ORNL MS\")\n", " diff = sino_bm3dornl_ms - sino_noisy\n", " cval = np.max(np.abs(diff))\n", - " axs[1, 3].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[1, 3].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[1, 3].set_title(f\"BM3D ORNL MS - Input: cval={cval}\")\n", " diff = sino_bm3dornl_ms - sino_bm3dornl\n", " cval = np.max(np.abs(diff))\n", - " axs[1, 4].imshow(diff, cmap='bwr', vmin=-cval, vmax=cval)\n", + " axs[1, 4].imshow(diff, cmap=\"bwr\", vmin=-cval, vmax=cval)\n", " axs[1, 4].set_title(f\"BM3D ORNL MS - BM3D ORNL: cval={cval}\")\n", " return fig, axs" ]