From 1ece71e2d4d3ce5c96af8756bcb8a4e32b3f8a82 Mon Sep 17 00:00:00 2001
From: Kai Zhang
Date: Sun, 15 Dec 2019 20:43:50 +0100
Subject: [PATCH] Add files via upload

---
 main_test_dncnn.py           | 202 ++++++++++++++++++++++++++++
 main_test_dpsr.py            | 214 +++++++++++++++++++++++++++++
 main_test_fdncnn.py          | 205 ++++++++++++++++++++++++++++
 main_test_ffdnet.py          | 198 +++++++++++++++++++++++++++
 main_test_imdn.py            | 212 +++++++++++++++++++++++++++++
 main_test_msrresnet.py       | 213 +++++++++++++++++++++++++++++
 main_test_rrdb.py            | 205 ++++++++++++++++++++++++++++
 main_test_srmd.py            | 233 ++++++++++++++++++++++++++++++++
 main_train_dncnn.py          | 250 ++++++++++++++++++++++++++++++++++
 main_train_dpsr.py           | 242 +++++++++++++++++++++++++++++++++
 main_train_fdncnn.py         | 247 ++++++++++++++++++++++++++++++++++
 main_train_ffdnet.py         | 245 +++++++++++++++++++++++++++++++++
 main_train_imdn.py           | 237 ++++++++++++++++++++++++++++++++
 main_train_msrresnet_gan.py  | 224 ++++++++++++++++++++++++++++++
 main_train_msrresnet_psnr.py | 220 ++++++++++++++++++++++++++++++
 main_train_rrdb_psnr.py      | 229 +++++++++++++++++++++++++++++++
 main_train_srmd.py           | 254 +++++++++++++++++++++++++++++++++++
 17 files changed, 3830 insertions(+)
 create mode 100644 main_test_dncnn.py
 create mode 100644 main_test_dpsr.py
 create mode 100644 main_test_fdncnn.py
 create mode 100644 main_test_ffdnet.py
 create mode 100644 main_test_imdn.py
 create mode 100644 main_test_msrresnet.py
 create mode 100644 main_test_rrdb.py
 create mode 100644 main_test_srmd.py
 create mode 100644 main_train_dncnn.py
 create mode 100644 main_train_dpsr.py
 create mode 100644 main_train_fdncnn.py
 create mode 100644 main_train_ffdnet.py
 create mode 100644 main_train_imdn.py
 create mode 100644 main_train_msrresnet_gan.py
 create mode 100644 main_train_msrresnet_psnr.py
 create mode 100644 main_train_rrdb_psnr.py
 create mode 100644 main_train_srmd.py

diff --git a/main_test_dncnn.py b/main_test_dncnn.py
new file mode 100644
index 00000000..6a3aa35e
--- /dev/null
+++ b/main_test_dncnn.py
@@ -0,0 +1,202 @@
+import os.path
+import logging
+
+import numpy as np
+from datetime import datetime
+from collections import OrderedDict
+from scipy.io import loadmat
+
+import torch
+
+from utils import utils_logger
+from utils import utils_model
+from utils import utils_image as util
+
+
+'''
+Spyder (Python 3.6)
+PyTorch 1.1.0
+Windows 10 or Linux
+
+Kai Zhang (cskaizhang@gmail.com)
+github: https://github.com/cszn/KAIR
+        https://github.com/cszn/DnCNN
+
+@article{zhang2017beyond,
+    title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
+    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
+    journal={IEEE Transactions on Image Processing},
+    volume={26},
+    number={7},
+    pages={3142--3155},
+    year={2017},
+    publisher={IEEE}
+}
+
+% If you have any question, please feel free to contact with me.
+% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +|--model_zoo # model_zoo + |--dncnn_15 # model_name + |--dncnn_25 + |--dncnn_50 + |--dncnn_gray_blind + |--dncnn_color_blind + |--dncnn3 +|--testset # testsets + |--set12 # testset_name + |--bsd68 + |--cbsd68 +|--results # results + |--set12_dncnn_15 # result_name = testset_name + '_' + model_name + |--set12_dncnn_25 + |--bsd68_dncnn_15 +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + noise_level_img = 25 # noise level for noisy image + model_name = 'dncnn_25' # 'dncnn_15' | 'dncnn_25' | 'dncnn_50' | 'dncnn_gray_blind' | 'dncnn_color_blind' | 'dncnn3' + testset_name = 'bsd68' # test set, 'bsd68' | 'set12' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance + show_img = False # default: False + + + + + task_current = 'dn' # 'dn' for denoising | 'sr' for super-resolution + sf = 1 # unused for denoising + if 'color' in model_name: + n_channels = 3 # fixed, 1 for grayscale image, 3 for color image + else: + n_channels = 1 # fixed for grayscale image + if model_name in ['dncnn_gray_blind', 'dncnn_color_blind', 'dncnn3']: + nb = 20 # fixed + else: + nb = 17 # fixed + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + result_name = testset_name + '_' + model_name # fixed + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_dncnn import DnCNN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + + logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, 
img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + if need_degradation: # degradation process + np.random.seed(seed=0) # for reproducibility + img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+ext)) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim)) + +if __name__ == '__main__': + + main() diff --git a/main_test_dpsr.py b/main_test_dpsr.py new file mode 100644 index 00000000..15c106bc --- /dev/null +++ b/main_test_dpsr.py @@ -0,0 +1,214 @@ +import os.path +import logging +import re + +import numpy as np +from collections import OrderedDict + +import torch + +from utils import utils_logger +from utils import utils_image as util +from utils import utils_model + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + https://github.com/cszn/DPSR + +@inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} +} + +% If you have any question, please feel free to contact with me. 
+% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +testing code for the super-resolver prior of DPSR +# -------------------------------------------- +|--model_zoo # model_zoo + |--dpsr_x2 # model_name, optimized for PSNR + |--dpsr_x3 + |--dpsr_x4 + |--dpsr_x4_gan # model_name, optimized for perceptual quality +|--testset # testsets + |--set5 # testset_name + |--srbsd68 +|--results # results + |--set5_dpsr_x2 # result_name = testset_name + '_' + model_name + |--set5_dpsr_x3 + |--set5_dpsr_x4 + |--set5_dpsr_x4_gan + |--srbsd68_dpsr_x4_gan +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + noise_level_img = 0 # default: 0, noise level for LR image + noise_level_model = noise_level_img # noise level for model + model_name = 'dpsr_x4_gan' # 'dpsr_x2' | 'dpsr_x3' | 'dpsr_x4' | 'dpsr_x4_gan' + testset_name = 'set5' # test set, 'set5' | 'srbsd68' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance + sf = [int(s) for s in re.findall(r'\d+', model_name)][0] # scale factor + show_img = False # default: False + + + + task_current = 'sr' # 'dn' for denoising | 'sr' for super-resolution + n_channels = 3 # fixed + nc = 96 # fixed, number of channels + nb = 16 # fixed, number of conv layers + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_dpsr import MSRResNet_prior as net + model = net(in_nc=n_channels+1, out_nc=n_channels, nc=nc, nb=nb, upscale=sf, act_mode='R', upsample_mode='pixelshuffle') + model.load_state_dict(torch.load(model_path), strict=False) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # 
------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + # degradation process, bicubic downsampling + Gaussian noise + if need_degradation: + img_L = util.modcrop(img_L, sf) + img_L = util.imresize_np(img_L, 1/sf) + np.random.seed(seed=0) # for reproducibility + img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + noise_level_map = torch.full((1, 1, img_L.size(2), img_L.size(3)), noise_level_model/255.).type_as(img_L) + img_L = torch.cat((img_L, noise_level_map), dim=1) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + img_H = util.modcrop(img_H, sf) + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + if np.ndim(img_H) == 3: # RGB image + img_E_y = util.rgb2ycbcr(img_E, only_y=True) + img_H_y = util.rgb2ycbcr(img_H, only_y=True) + psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border) + ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr, ave_ssim)) + if np.ndim(img_H) == 3: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info('Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr_y, ave_ssim_y)) + +if __name__ == '__main__': + + main() diff --git a/main_test_fdncnn.py b/main_test_fdncnn.py new file mode 100644 index 00000000..57cb262a --- /dev/null +++ b/main_test_fdncnn.py @@ -0,0 +1,205 @@ +import os.path +import logging + +import numpy as np +from collections import OrderedDict +from scipy.io import loadmat + +import torch + +from utils import utils_logger +from utils import utils_model +from utils import utils_image as util + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + 
https://github.com/cszn/DnCNN + https://github.com/cszn/FFDNet + +@article{zhang2018ffdnet, + title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + journal={IEEE Transactions on Image Processing}, + volume={27}, + number={9}, + pages={4608--4622}, + year={2018}, + publisher={IEEE} +} + +% If you have any question, please feel free to contact with me. +% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +|--model_zoo # model_zoo + |--fdncnn_color # model_name, for color images + |--fdncnn_gray + |--fdncnn_color_clip # for clipped uint8 color images + |--fdncnn_gray_clip +|--testset # testsets + |--set12 # testset_name + |--bsd68 + |--cbsd68 +|--results # results + |--set12_fdncnn_color # result_name = testset_name + '_' + model_name + |--set12_fdncnn_gray + |--cbsd68_fdncnn_color_clip +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + noise_level_img = 15 # noise level for noisy image + noise_level_model = noise_level_img # noise level for model + model_name = 'fdncnn_gray' # 'fdncnn_gray' | 'fdncnn_color' | 'fdncnn_color_clip' | 'fdncnn_gray_clip' + testset_name = 'bsd68' # test set, 'bsd68' | 'cbsd68' | 'set12' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance + show_img = False # default: Falsedefault: False + + + + + task_current = 'dn' # 'dn' for denoising | 'sr' for super-resolution + sf = 1 # unused for denoising + if 'color' in model_name: + n_channels = 3 # 3 for color image + else: + n_channels = 1 # 1 for grayscale image + if 'clip' in model_name: + use_clip = True # clip the intensities into range of [0, 1] + else: + use_clip = False + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_dncnn import FDnCNN as net + model = net(in_nc=n_channels+1, out_nc=n_channels, nc=64, nb=20, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = 
OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + + logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + if need_degradation: # degradation process + np.random.seed(seed=0) # for reproducibility + img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + if use_clip: + img_L = util.uint2single(util.single2uint(img_L)) + + util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + noise_level_map = torch.ones((1, 1, img_L.size(2), img_L.size(3)), dtype=torch.float).mul_(noise_level_model/255.) + img_L = torch.cat((img_L, noise_level_map), dim=1) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+ext)) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim)) + +if __name__ == '__main__': + + main() diff --git a/main_test_ffdnet.py b/main_test_ffdnet.py new file mode 100644 index 00000000..9407259b --- /dev/null +++ b/main_test_ffdnet.py @@ -0,0 +1,198 @@ +import os.path +import logging + +import numpy as np +from collections import OrderedDict + +import torch + +from utils import utils_logger +from utils import utils_image as util + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + https://github.com/cszn/FFDNet + +@article{zhang2018ffdnet, + title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + journal={IEEE Transactions on Image Processing}, + volume={27}, + number={9}, + pages={4608--4622}, + year={2018}, + publisher={IEEE} +} + +% If you have any question, please feel free to contact with me. 
+% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +|--model_zoo # model_zoo + |--ffdnet_gray # model_name, for color images + |--ffdnet_color + |--ffdnet_color_clip # for clipped uint8 color images + |--ffdnet_gray_clip +|--testset # testsets + |--set12 # testset_name + |--bsd68 + |--cbsd68 +|--results # results + |--set12_ffdnet_gray # result_name = testset_name + '_' + model_name + |--set12_ffdnet_color + |--cbsd68_ffdnet_color_clip +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + noise_level_img = 15 # noise level for noisy image + noise_level_model = noise_level_img # noise level for model + model_name = 'ffdnet_gray' # 'ffdnet_gray' | 'ffdnet_color' | 'ffdnet_color_clip' | 'ffdnet_gray_clip' + testset_name = 'bsd68' # test set, 'bsd68' | 'cbsd68' | 'set12' + need_degradation = True # default: True + show_img = False # default: False + + + + + task_current = 'dn' # 'dn' for denoising | 'sr' for super-resolution + sf = 1 # unused for denoising + if 'color' in model_name: + n_channels = 3 # setting for color image + nc = 96 # setting for color image + nb = 12 # setting for color image + else: + n_channels = 1 # setting for grayscale image + nc = 64 # setting for grayscale image + nb = 15 # setting for grayscale image + if 'clip' in model_name: + use_clip = True # clip the intensities into range of [0, 1] + else: + use_clip = False + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_ffdnet import FFDNet as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + + logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = 
os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + if need_degradation: # degradation process + np.random.seed(seed=0) # for reproducibility + img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + if use_clip: + img_L = util.uint2single(util.single2uint(img_L)) + + util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + img_L = img_L.to(device) + + sigma = torch.full((1,1,1,1), noise_level_model/255.).type_as(img_L) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + img_E = model(img_L, sigma) + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+ext)) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim)) + +if __name__ == '__main__': + + main() diff --git a/main_test_imdn.py b/main_test_imdn.py new file mode 100644 index 00000000..1c597a00 --- /dev/null +++ b/main_test_imdn.py @@ -0,0 +1,212 @@ +import os.path +import logging +import re + +import numpy as np +from collections import OrderedDict + +import torch + +from utils import utils_logger +from utils import utils_image as util +from utils import utils_model + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + +If you have any question, please feel free to contact with me. 
+Kai Zhang (e-mail: cskaizhang@gmail.com) +(github: https://github.com/cszn/KAIR) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +# simplified information multi-distillation +# network (IMDN) for SR +# -------------------------------------------- +@inproceedings{hui2019lightweight, + title={Lightweight Image Super-Resolution with Information Multi-distillation Network}, + author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei}, + booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)}, + pages={2024--2032}, + year={2019} +} +@inproceedings{zhang2019aim, + title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results}, + author={Kai Zhang and Shuhang Gu and Radu Timofte and others}, + booktitle={IEEE International Conference on Computer Vision Workshops}, + year={2019} +} +# -------------------------------------------- +|--model_zoo # model_zoo + |--imdn_x4 # model_name, optimized for PSNR +|--testset # testsets + |--set5 # testset_name + |--srbsd68 +|--results # results + |--set5_imdn_x4 # result_name = testset_name + '_' + model_name +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + model_name = 'imdn_x4' # 'imdn_x4' + testset_name = 'set5' # test set, 'set5' | 'srbsd68' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance, default: False + sf = [int(s) for s in re.findall(r'\d+', model_name)][0] # scale factor + show_img = False # default: False + + + + + task_current = 'sr' # 'dn' for denoising | 'sr' for super-resolution + n_channels = 3 # fixed + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + noise_level_img = 0 # fixed: 0, noise level for LR image + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_imdn import IMDN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') + + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + 
test_results['ssim_y'] = [] + + logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + # degradation process, bicubic downsampling + if need_degradation: + img_L = util.modcrop(img_L, sf) + img_L = util.imresize_np(img_L, 1/sf) + # img_L = util.uint2single(util.single2uint(img_L)) + # np.random.seed(seed=0) # for reproducibility + # img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + img_H = util.modcrop(img_H, sf) + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + if np.ndim(img_H) == 3: # RGB image + img_E_y = util.rgb2ycbcr(img_E, only_y=True) + img_H_y = util.rgb2ycbcr(img_H, only_y=True) + psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border) + ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr, ave_ssim)) + if np.ndim(img_H) == 3: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info('Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr_y, ave_ssim_y)) + +if __name__ == '__main__': + + main() diff --git a/main_test_msrresnet.py b/main_test_msrresnet.py new file mode 100644 index 00000000..207498bc --- /dev/null +++ b/main_test_msrresnet.py @@ -0,0 +1,213 @@ +import os.path +import logging +import re + +import numpy as np +from collections import OrderedDict + +import torch + +from utils import utils_logger +from utils import utils_image as util +from utils import 
utils_model + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + +If you have any question, please feel free to contact with me. +Kai Zhang (e-mail: cskaizhang@gmail.com) +(github: https://github.com/cszn/KAIR) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +testing demo for RRDB-ESRGAN +https://github.com/xinntao/ESRGAN +@inproceedings{wang2018esrgan, + title={Esrgan: Enhanced super-resolution generative adversarial networks}, + author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Change Loy, Chen}, + booktitle={European Conference on Computer Vision (ECCV)}, + pages={0--0}, + year={2018} +} +@inproceedings{ledig2017photo, + title={Photo-realistic single image super-resolution using a generative adversarial network}, + author={Ledig, Christian and Theis, Lucas and Husz{\'a}r, Ferenc and Caballero, Jose and Cunningham, Andrew and Acosta, Alejandro and Aitken, Andrew and Tejani, Alykhan and Totz, Johannes and Wang, Zehan and others}, + booktitle={IEEE conference on computer vision and pattern recognition}, + pages={4681--4690}, + year={2017} +} +# -------------------------------------------- +|--model_zoo # model_zoo + |--msrresnet_x4_gan # model_name, optimized for perceptual quality + |--msrresnet_x4_psnr # model_name, optimized for PSNR +|--testset # testsets + |--set5 # testset_name + |--srbsd68 +|--results # results + |--set5_msrresnet_x4_gan # result_name = testset_name + '_' + model_name + |--set5_msrresnet_x4_psnr +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + model_name = 'msrresnet_x4_psnr' # 'msrresnet_x4_gan' | 'msrresnet_x4_psnr' + testset_name = 'set5' # test set, 'set5' | 'srbsd68' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance, default: False + sf = [int(s) for s in re.findall(r'\d+', model_name)][0] # scale factor + show_img = False # default: False + + + + + task_current = 'sr' # 'dn' for denoising | 'sr' for super-resolution + n_channels = 3 # fixed + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + noise_level_img = 0 # fixed: 0, noise level for LR image + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_msrresnet import MSRResNet1 as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=16, upscale=4) 
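# The checkpoint below is loaded with strict=True, so every key in the .pth file must
# match this MSRResNet1 definition exactly. A sketch (an assumption, not part of the
# original script) of a CPU-safe load using only the standard torch API:
#   state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
#   model.load_state_dict(state_dict, strict=True)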
+ model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + # degradation process, bicubic downsampling + if need_degradation: + img_L = util.modcrop(img_L, sf) + img_L = util.imresize_np(img_L, 1/sf) + # img_L = util.uint2single(util.single2uint(img_L)) + # np.random.seed(seed=0) # for reproducibility + # img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + img_H = util.modcrop(img_H, sf) + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + if np.ndim(img_H) == 3: # RGB image + img_E_y = util.rgb2ycbcr(img_E, only_y=True) + img_H_y = util.rgb2ycbcr(img_H, only_y=True) + psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border) + ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr, ave_ssim)) + if np.ndim(img_H) == 3: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info('Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: 
{:.4f}'.format(result_name, sf, ave_psnr_y, ave_ssim_y)) + +if __name__ == '__main__': + + main() diff --git a/main_test_rrdb.py b/main_test_rrdb.py new file mode 100644 index 00000000..9ab8f72f --- /dev/null +++ b/main_test_rrdb.py @@ -0,0 +1,205 @@ +import os.path +import logging +import re + +import numpy as np +from collections import OrderedDict + +import torch + +from utils import utils_logger +from utils import utils_image as util +from utils import utils_model + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + +If you have any question, please feel free to contact with me. +Kai Zhang (e-mail: cskaizhang@gmail.com) +(github: https://github.com/cszn/KAIR) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +testing demo for RRDB-ESRGAN +https://github.com/xinntao/ESRGAN +@inproceedings{wang2018esrgan, + title={Esrgan: Enhanced super-resolution generative adversarial networks}, + author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Change Loy, Chen}, + booktitle={European Conference on Computer Vision (ECCV)}, + pages={0--0}, + year={2018} +} +# -------------------------------------------- +|--model_zoo # model_zoo + |--rrdb_x4_esrgan # model_name, optimized for perceptual quality + |--rrdb_x4_psnr # model_name, optimized for PSNR +|--testset # testsets + |--set5 # testset_name + |--srbsd68 +|--results # results + |--set5_rrdb_x4_esrgan# result_name = testset_name + '_' + model_name + |--set5_rrdb_x4_psnr +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + model_name = 'rrdb_x4_esrgan' # 'rrdb_x4_esrgan' | 'rrdb_x4_psnr' + testset_name = 'set5' # test set, 'set5' | 'srbsd68' + need_degradation = True # default: True + x8 = False # default: False, x8 to boost performance + sf = [int(s) for s in re.findall(r'\d+', model_name)][0] # scale factor + show_img = False # default: False + + + + + task_current = 'sr' # 'dn' for denoising | 'sr' for super-resolution + n_channels = 3 # fixed + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + noise_level_img = 0 # fixed: 0, noise level for LR image + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_rrdb import RRDB as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') + 
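# The RRDB weights below are loaded with strict=False, which silently skips any
# checkpoint keys that do not match the network definition. A sketch (assuming
# PyTorch >= 1.1) for surfacing such mismatches instead of ignoring them:
#   incompatible = model.load_state_dict(torch.load(model_path), strict=False)
#   print('missing keys:', incompatible.missing_keys)
#   print('unexpected keys:', incompatible.unexpected_keys)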
model.load_state_dict(torch.load(model_path), strict=False) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + # degradation process, bicubic downsampling + Gaussian noise + if need_degradation: + img_L = util.modcrop(img_L, sf) + img_L = util.imresize_np(img_L, 1/sf) + # np.random.seed(seed=0) # for reproducibility + # img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + img_H = util.modcrop(img_H, sf) + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + if np.ndim(img_H) == 3: # RGB image + img_E_y = util.rgb2ycbcr(img_E, only_y=True) + img_H_y = util.rgb2ycbcr(img_H, only_y=True) + psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border) + ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr, ave_ssim)) + if np.ndim(img_H) == 3: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info('Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, 
ave_psnr_y, ave_ssim_y)) + +if __name__ == '__main__': + + main() diff --git a/main_test_srmd.py b/main_test_srmd.py new file mode 100644 index 00000000..8a72ab5c --- /dev/null +++ b/main_test_srmd.py @@ -0,0 +1,233 @@ +import os.path +import logging +import re + +import numpy as np +from collections import OrderedDict +from scipy.io import loadmat + +import torch + +from utils import utils_deblur +from utils import utils_sisr as sr +from utils import utils_logger +from utils import utils_image as util +from utils import utils_model + + +''' +Spyder (Python 3.6) +PyTorch 1.1.0 +Windows 10 or Linux + +Kai Zhang (cskaizhang@gmail.com) +github: https://github.com/cszn/KAIR + https://github.com/cszn/SRMD + +@inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} +} + +% If you have any question, please feel free to contact with me. +% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) + +by Kai Zhang (12/Dec./2019) +''' + +""" +# -------------------------------------------- +|--model_zoo # model_zoo + |--srmdnf_x2 # model_name, for noise-free LR image SR + |--srmdnf_x3 + |--srmdnf_x4 + |--srmd_x2 # model_name, for noisy LR image + |--srmd_x3 + |--srmd_x4 +|--testset # testsets + |--set5 # testset_name + |--srbsd68 +|--results # results + |--set5_srmdnf_x2 # result_name = testset_name + '_' + model_name + |--set5_srmdnf_x3 + |--set5_srmdnf_x4 + |--set5_srmd_x2 + |--srbsd68_srmd_x2 +# -------------------------------------------- +""" + + +def main(): + + # ---------------------------------------- + # Preparation + # ---------------------------------------- + + noise_level_img = 0 # default: 0, noise level for LR image + noise_level_model = noise_level_img # noise level for model + model_name = 'srmdnf_x4' # 'srmd_x2' | 'srmd_x3' | 'srmd_x4' | 'srmdnf_x2' | 'srmdnf_x3' | 'srmdnf_x4' + testset_name = 'set5' # test set, 'set5' | 'srbsd68' + sf = [int(s) for s in re.findall(r'\d+', model_name)][0] # scale factor + x8 = False # default: False, x8 to boost performance + need_degradation = True # default: True, use degradation model to generate LR image + show_img = False # default: False + + + + + srmd_pca_path = os.path.join('kernels', 'srmd_pca_matlab.mat') + task_current = 'sr' # 'dn' for denoising | 'sr' for super-resolution + n_channels = 3 # fixed + in_nc = 18 if 'nf' in model_name else 19 + nc = 128 # fixed, number of channels + nb = 12 # fixed, number of conv layers + model_pool = 'model_zoo' # fixed + testsets = 'testsets' # fixed + results = 'results' # fixed + result_name = testset_name + '_' + model_name + border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM + model_path = os.path.join(model_pool, model_name+'.pth') + + # ---------------------------------------- + # L_path, E_path, H_path + # ---------------------------------------- + + L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images + H_path = L_path # H_path, for High-quality images + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + if H_path == L_path: + need_degradation = True + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + need_H = True if H_path is 
not None else False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_srmd import SRMD as net + model = net(in_nc=in_nc, out_nc=n_channels, nc=nc, nb=nb, upscale=sf, act_mode='R', upsample_mode='pixelshuffle') + model.load_state_dict(torch.load(model_path), strict=False) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) + logger.info('Params number: {}'.format(number_parameters)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnr_y'] = [] + test_results['ssim_y'] = [] + + logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model)) + logger.info(L_path) + L_paths = util.get_image_paths(L_path) + H_paths = util.get_image_paths(H_path) if need_H else None + + # ---------------------------------------- + # kernel and PCA reduced feature + # ---------------------------------------- + + # kernel = sr.anisotropic_Gaussian(ksize=15, theta=np.pi, l1=4, l2=4) + kernel = utils_deblur.fspecial('gaussian', 15, 0.01) # Gaussian kernel, delta kernel 0.01 + + P = loadmat(srmd_pca_path)['P'] + degradation_vector = np.dot(P, np.reshape(kernel, (-1), order="F")) + if 'nf' not in model_name: # noise-free SR + degradation_vector = np.append(degradation_vector, noise_level_model/255.) + degradation_vector = torch.from_numpy(degradation_vector).view(1, -1, 1, 1).float() + + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + + img_name, ext = os.path.splitext(os.path.basename(img)) + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + img_L = util.uint2single(img_L) + + # degradation process, blur + bicubic downsampling + Gaussian noise + if need_degradation: + img_L = util.modcrop(img_L, sf) + img_L = sr.srmd_degradation(img_L, kernel, sf) # equivalent to bicubic degradation if kernel is a delta kernel + np.random.seed(seed=0) # for reproducibility + img_L += np.random.normal(0, noise_level_img/255., img_L.shape) + + util.imshow(util.single2uint(img_L), title='LR image with noise level {}'.format(noise_level_img)) if show_img else None + + img_L = util.single2tensor4(img_L) + degradation_map = degradation_vector.repeat(1, 1, img_L.size(-2), img_L.size(-1)) + img_L = torch.cat((img_L, degradation_map), dim=1) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + if not x8: + img_E = model(img_L) + else: + img_E = utils_model.test_mode(model, img_L, mode=3, sf=sf) + + img_E = util.tensor2uint(img_E) + + if need_H: + + # -------------------------------- + # (3) img_H + # -------------------------------- + + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) + img_H = img_H.squeeze() + img_H = util.modcrop(img_H, sf) + + # -------------------------------- + # PSNR and SSIM + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + logger.info('{:s} - PSNR: {:.2f} dB; 
SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + + if np.ndim(img_H) == 3: # RGB image + img_E_y = util.rgb2ycbcr(img_E, only_y=True) + img_H_y = util.rgb2ycbcr(img_H, only_y=True) + psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border) + ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border) + test_results['psnr_y'].append(psnr_y) + test_results['ssim_y'].append(ssim_y) + + # ------------------------------------ + # save results + # ------------------------------------ + + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if need_H: + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + logger.info('Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr, ave_ssim)) + if np.ndim(img_H) == 3: + ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) + ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) + logger.info('Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, sf, ave_psnr_y, ave_ssim_y)) + +if __name__ == '__main__': + + main() diff --git a/main_train_dncnn.py b/main_train_dncnn.py new file mode 100644 index 00000000..7cf0c10c --- /dev/null +++ b/main_train_dncnn.py @@ -0,0 +1,250 @@ +import os.path +import math +import argparse +import time +import random +import numpy as np +from collections import OrderedDict +import logging +import torch +from torch.utils.data import DataLoader + + +from utils import utils_logger +from utils import utils_image as util +from utils import utils_option as option + +from data.select_dataset import define_Dataset +from models.select_model import define_Model + + +''' +# -------------------------------------------- +# training code for DnCNN +# -------------------------------------------- +# Kai Zhang (cskaizhang@gmail.com) +# github: https://github.com/cszn/KAIR +# https://github.com/cszn/DnCNN +# +# Reference: +@article{zhang2017beyond, + title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising}, + author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei}, + journal={IEEE Transactions on Image Processing}, + volume={26}, + number={7}, + pages={3142--3155}, + year={2017}, + publisher={IEEE} +} +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def main(json_path='options/train_dncnn.json'): + + ''' + # ---------------------------------------- + # Step--1 (prepare opt) + # ---------------------------------------- + ''' + + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, default=json_path, help='Path to option JSON file.') + + opt = option.parse(parser.parse_args().opt, is_train=True) + util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key)) + + # ---------------------------------------- + # update opt + # ---------------------------------------- + # -->-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = 0 + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # 
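For reference, the per-pixel degradation map that main_test_srmd.py (above) feeds to the network is just the PCA-projected blur kernel, optionally extended by the noise level, tiled over the LR image and concatenated along the channel axis. Note that the noise level is appended only for the noisy-input models ('srmd_x*', in_nc = 19); the 'srmdnf_x*' variants use just the 15-dim kernel code (in_nc = 18). A standalone shape sketch with stand-in values (the real kernel comes from utils_deblur.fspecial and the real projection matrix P from kernels/srmd_pca_matlab.mat):

import numpy as np
import torch

kernel = np.ones((15, 15), dtype=np.float32) / 225.0        # toy blur kernel (stand-in)
P = np.random.randn(15, 225).astype(np.float32)             # toy PCA basis, 15 x (15*15) (stand-in)
noise_level_model = 15 / 255.0

vec = P.dot(kernel.reshape(-1, order='F'))                   # 15-dim kernel code
vec = np.append(vec, noise_level_model)                      # extra channel for the noisy 'srmd_x*' models
vec = torch.from_numpy(vec).view(1, -1, 1, 1).float()        # (1, 16, 1, 1)

img_L = torch.rand(1, 3, 64, 64)                             # toy LR image
degradation_map = vec.repeat(1, 1, img_L.size(-2), img_L.size(-1))
net_input = torch.cat((img_L, degradation_map), dim=1)       # (1, 19, 64, 64) -> in_nc = 19
print(net_input.shape)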
---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + dataset_type = opt['datasets']['train']['dataset_type'] + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + if opt['merge_bn'] and current_step > opt['merge_bn_startpoint']: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_test() + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + if dataset_type == 'dnpatch' and current_step % 20000 == 0: # for 'train400' + train_loader.dataset.update_data() + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # merge bnorm + # ------------------------------- + if opt['merge_bn'] and opt['merge_bn_startpoint'] == current_step: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_train() + model.print_network() + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' 
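The merge_bn steps in main_train_dncnn.py above fold each BatchNorm layer into the preceding convolution once training has stabilised; model.merge_bnorm_train()/merge_bnorm_test() implement this inside the repo's model classes. The following is only a standalone sketch of the underlying algebra, assuming a plain Conv2d followed by BatchNorm2d (default groups and dilation):

import torch
import torch.nn as nn

@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # y = gamma*(W*x + b - mean)/sqrt(var+eps) + beta collapses into one convolution:
    #   W' = W * gamma/sqrt(var+eps),  b' = (b - mean)*gamma/sqrt(var+eps) + beta
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

# quick check: identical outputs (in eval mode) before and after folding
conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
bn.eval()
x = torch.rand(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5))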
file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + +# if opt['merge_bn'] and current_step > opt['merge_bn_startpoint']: +# logger.info('^_^ -----merging bnorm----- ^_^') +# model.merge_bnorm_test() + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # merge bnorm + # ------------------------------- +# if opt['merge_bn'] and opt['merge_bn_startpoint'] == current_step: +# logger.info('^_^ -----merging bnorm----- ^_^') +# model.merge_bnorm_train() +# model.print_network() + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = 0 + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # 
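Each trainer above resumes from the most recent generator checkpoint via option.find_last_checkpoint and restarts the iteration counter from it. An illustrative sketch of that resume step, assuming checkpoints saved as '<iter>_G.pth' (the helper below is mine, not the repo's implementation):

import glob
import os

def find_last_checkpoint(models_dir, net_type='G'):
    paths = glob.glob(os.path.join(models_dir, '*_{}.pth'.format(net_type)))
    if not paths:
        return 0, None                               # fresh run: step 0, no pretrained weights
    last = max(int(os.path.basename(p).split('_')[0]) for p in paths)
    return last, os.path.join(models_dir, '{}_{}.pth'.format(last, net_type))

init_iter, init_path_G = find_last_checkpoint('path/to/models')  # directory is illustrative
current_step = init_iter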
return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + if opt['merge_bn'] and current_step > opt['merge_bn_startpoint']: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_train() + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # merge bnorm + # ------------------------------- + if opt['merge_bn'] and opt['merge_bn_startpoint'] == current_step: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_train() + model.print_network() + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = 0 + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return 
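During the periodic testing above, the estimated image 'E' and the ground truth 'H' are converted from tensors to uint8 images before PSNR is computed. A rough standalone equivalent of that conversion (util.tensor2uint is the repo's version; this helper is only an approximation of it):

import numpy as np
import torch

def to_uint8_image(t: torch.Tensor) -> np.ndarray:
    img = t.detach().squeeze(0).clamp(0.0, 1.0).cpu().numpy()    # C,H,W (or H,W) in [0, 1]
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))                        # H,W,C
    return np.uint8((img * 255.0).round())

print(to_uint8_image(torch.rand(1, 3, 8, 8)).shape)  # (8, 8, 3)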
None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + if opt['merge_bn'] and current_step > opt['merge_bn_startpoint']: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_test() + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # merge bnorm + # ------------------------------- + if opt['merge_bn'] and opt['merge_bn_startpoint'] == current_step: + logger.info('^_^ -----merging bnorm----- ^_^') + model.merge_bnorm_train() + model.print_network() + + # -------------------------- + # 4) training information + # -------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return 
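The super-resolution trainers set border = opt['scale'] so that a margin of sf pixels is shaved from each side before PSNR is computed (boundary pixels are not reliably reconstructed), while the denoising trainers use border = 0. A minimal sketch of border-aware PSNR for 8-bit images, playing the role of util.calculate_psnr:

import numpy as np

def psnr_with_border(img_E, img_H, border=0):
    # shave `border` pixels from each side, then standard 8-bit PSNR
    if border > 0:
        img_E = img_E[border:-border, border:-border, ...]
        img_H = img_H[border:-border, border:-border, ...]
    mse = np.mean((img_E.astype(np.float64) - img_H.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20.0 * np.log10(255.0 / np.sqrt(mse))

a = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
print(psnr_with_border(a, a, border=4))  # inf for identical images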
None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iterG, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + init_iterD, init_path_D = option.find_last_checkpoint(opt['path']['models'], net_type='D') + opt['path']['pretrained_netG'] = init_path_G + opt['path']['pretrained_netD'] = init_path_D + current_step = max(init_iterG, init_iterD) + + # opt['path']['pretrained_netG'] = '' + # current_step = 0 + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # 
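The GAN trainer above locates the last generator and discriminator checkpoints separately and resumes from the newer of the two, so the step counter never moves backwards when G and D were last saved at different iterations. In short (values illustrative):

init_iterG, init_iterD = 95000, 90000
current_step = max(init_iterG, init_iterD)   # resume from the newer of the two checkpoints
print(current_step)                          # 95000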
---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." % phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = 
os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + 
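In every training loop above, the checkpoint_print branch assembles its log line with ' '.format(epoch, current_step, model.current_learning_rate()), a format string with no placeholders, so the epoch, iteration and learning rate never reach the log; the intended prefix appears to have been lost from the diff. A sketch of one reasonable message layout (the exact format is an assumption, and the values below are stand-ins):

logs = {'G_loss': 1.234e-3}                  # stand-in for model.current_log()
epoch, current_step, lr = 3, 40000, 1e-4     # stand-ins for the real loop variables

message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step, lr)
for k, v in logs.items():
    message += '{:s}: {:.3e} '.format(k, v)
print(message)  # <epoch:  3, iter:  40,000, lr:1.000e-04> G_loss: 1.234e-03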
logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." % phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = 
util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('-->-->-->-->-->-->-->-->-->-->-->-->- + init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') + opt['path']['pretrained_netG'] = init_path_G + current_step = init_iter + + border = opt['scale'] + # --<--<--<--<--<--<--<--<--<--<--<--<--<- + + # ---------------------------------------- + # save opt to a '../option.json' file + # ---------------------------------------- + option.save(opt) + + # ---------------------------------------- + # return None for missing key + # ---------------------------------------- + opt = option.dict_to_nonedict(opt) + + # ---------------------------------------- + # configure logger + # ---------------------------------------- + logger_name = 'train' + utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info(option.dict2str(opt)) + + # ---------------------------------------- + # calculate PCA projection matrix + # ---------------------------------------- + pca_matrix_path = os.path.join('kernels', 'srmd_pca_pytorch.mat') + if not os.path.exists(pca_matrix_path): + logger.info('calculating PCA projection matrix...') + sisr.cal_pca_matrix(path=pca_matrix_path, ksize=15, l_max=10.0, dim_pca=15, num_samples=5000) + logger.info('done!') + else: + logger.info('loading PCA projection matrix...') + + # ---------------------------------------- + # seed + # ---------------------------------------- + seed = opt['train']['manual_seed'] + if seed is None: + seed = random.randint(1, 10000) + logger.info('Random seed: {}'.format(seed)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + ''' + # ---------------------------------------- + # Step--2 (creat dataloader) + # ---------------------------------------- + ''' + + # ---------------------------------------- + # 1) create_dataset + # 2) creat_dataloader for train and test + # ---------------------------------------- + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + train_set = define_Dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) + logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) + train_loader = DataLoader(train_set, + batch_size=dataset_opt['dataloader_batch_size'], + shuffle=dataset_opt['dataloader_shuffle'], + num_workers=dataset_opt['dataloader_num_workers'], + drop_last=True, + pin_memory=True) + elif phase == 'test': + test_set = define_Dataset(dataset_opt) + test_loader = DataLoader(test_set, batch_size=1, + shuffle=False, num_workers=1, + drop_last=False, pin_memory=True) + else: + raise NotImplementedError("Phase [%s] is not recognized." 
% phase) + + ''' + # ---------------------------------------- + # Step--3 (initialize model) + # ---------------------------------------- + ''' + + model = define_Model(opt) + +# if opt['merge_bn'] and current_step > opt['merge_bn_startpoint']: +# logger.info('^_^ -----merging bnorm----- ^_^') +# model.merge_bnorm_test() + + logger.info(model.info_network()) + model.init_train() + logger.info(model.info_params()) + + ''' + # ---------------------------------------- + # Step--4 (main training) + # ---------------------------------------- + ''' + + for epoch in range(1000000): # keep running + for i, train_data in enumerate(train_loader): + + current_step += 1 + + # ------------------------------- + # 1) update learning rate + # ------------------------------- + model.update_learning_rate(current_step) + + # ------------------------------- + # 2) feed patch pairs + # ------------------------------- + model.feed_data(train_data) + + # ------------------------------- + # 3) optimize parameters + # ------------------------------- + model.optimize_parameters(current_step) + + # ------------------------------- + # merge bnorm + # ------------------------------- +# if opt['merge_bn'] and opt['merge_bn_startpoint'] == current_step: +# logger.info('^_^ -----merging bnorm----- ^_^') +# model.merge_bnorm_train() +# model.print_network() + + # ------------------------------- + # 4) training information + # ------------------------------- + if current_step % opt['train']['checkpoint_print'] == 0: + logs = model.current_log() # such as loss + message = ' '.format(epoch, current_step, model.current_learning_rate()) + for k, v in logs.items(): # merge log information into message + message += '{:s}: {:.3e} '.format(k, v) + logger.info(message) + + # ------------------------------- + # 5) save model + # ------------------------------- + if current_step % opt['train']['checkpoint_save'] == 0: + logger.info('Saving the model.') + model.save(current_step) + + # ------------------------------- + # 6) testing + # ------------------------------- + if current_step % opt['train']['checkpoint_test'] == 0: + + avg_psnr = 0.0 + idx = 0 + + for test_data in test_loader: + idx += 1 + image_name_ext = os.path.basename(test_data['L_path'][0]) + img_name, ext = os.path.splitext(image_name_ext) + + img_dir = os.path.join(opt['path']['images'], img_name) + util.mkdir(img_dir) + + model.feed_data(test_data) + model.test() + + visuals = model.current_visuals() + E_img = util.tensor2uint(visuals['E']) + H_img = util.tensor2uint(visuals['H']) + + # ----------------------- + # save estimated image E + # ----------------------- + save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) + util.imsave(E_img, save_img_path) + + # ----------------------- + # calculate PSNR + # ----------------------- + current_psnr = util.calculate_psnr(E_img, H_img, border=border) + + logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) + + avg_psnr += current_psnr + + avg_psnr = avg_psnr / idx + + # testing log + logger.info('
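The SRMD trainer above precomputes a PCA projection matrix over sampled blur kernels (sisr.cal_pca_matrix with ksize=15, dim_pca=15, num_samples=5000) so that every 15x15 kernel can be encoded as a 15-dim vector for the degradation map. A rough standalone sketch of the idea, using isotropic Gaussians as stand-ins for the repo's anisotropic kernel sampler:

import numpy as np

def gaussian_kernel(ksize=15, sigma=2.0):
    ax = np.arange(ksize) - (ksize - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    k = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return k / k.sum()

# sample kernels, flatten column-wise, keep the top 15 principal directions
samples = np.stack([gaussian_kernel(sigma=s).reshape(-1, order='F')
                    for s in np.random.uniform(0.2, 10.0, size=5000)], axis=1)   # (225, 5000)
U, _, _ = np.linalg.svd(samples - samples.mean(axis=1, keepdims=True), full_matrices=False)
P = U[:, :15].T                                                                   # (15, 225) projection matrix
code = P.dot(gaussian_kernel(sigma=3.0).reshape(-1, order='F'))                   # 15-dim kernel code
print(P.shape, code.shape)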