Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
cszn authored Dec 15, 2019
1 parent abbc06a commit 1ece71e
Show file tree
Hide file tree
Showing 17 changed files with 3,830 additions and 0 deletions.
202 changes: 202 additions & 0 deletions main_test_dncnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import os.path
import logging

import numpy as np
from datetime import datetime
from collections import OrderedDict
from scipy.io import loadmat

import torch

from utils import utils_logger
from utils import utils_model
from utils import utils_image as util


'''
Spyder (Python 3.6)
PyTorch 1.1.0
Windows 10 or Linux
Kai Zhang ([email protected])
github: https://github.com/cszn/KAIR
https://github.com/cszn/DnCNN
@article{zhang2017beyond,
title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={26},
number={7},
pages={3142--3155},
year={2017},
publisher={IEEE}
}
% If you have any question, please feel free to contact with me.
% Kai Zhang (e-mail: [email protected]; github: https://github.com/cszn)
by Kai Zhang (12/Dec./2019)
'''

"""
# --------------------------------------------
|--model_zoo # model_zoo
|--dncnn_15 # model_name
|--dncnn_25
|--dncnn_50
|--dncnn_gray_blind
|--dncnn_color_blind
|--dncnn3
|--testset # testsets
|--set12 # testset_name
|--bsd68
|--cbsd68
|--results # results
|--set12_dncnn_15 # result_name = testset_name + '_' + model_name
|--set12_dncnn_25
|--bsd68_dncnn_15
# --------------------------------------------
"""


def main():

# ----------------------------------------
# Preparation
# ----------------------------------------

noise_level_img = 25 # noise level for noisy image
model_name = 'dncnn_25' # 'dncnn_15' | 'dncnn_25' | 'dncnn_50' | 'dncnn_gray_blind' | 'dncnn_color_blind' | 'dncnn3'
testset_name = 'bsd68' # test set, 'bsd68' | 'set12'
need_degradation = True # default: True
x8 = False # default: False, x8 to boost performance
show_img = False # default: False




task_current = 'dn' # 'dn' for denoising | 'sr' for super-resolution
sf = 1 # unused for denoising
if 'color' in model_name:
n_channels = 3 # fixed, 1 for grayscale image, 3 for color image
else:
n_channels = 1 # fixed for grayscale image
if model_name in ['dncnn_gray_blind', 'dncnn_color_blind', 'dncnn3']:
nb = 20 # fixed
else:
nb = 17 # fixed
model_pool = 'model_zoo' # fixed
testsets = 'testsets' # fixed
results = 'results' # fixed
result_name = testset_name + '_' + model_name # fixed
border = sf if task_current == 'sr' else 0 # shave boader to calculate PSNR and SSIM
model_path = os.path.join(model_pool, model_name+'.pth')

# ----------------------------------------
# L_path, E_path, H_path
# ----------------------------------------

L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
H_path = L_path # H_path, for High-quality images
E_path = os.path.join(results, result_name) # E_path, for Estimated images
util.mkdir(E_path)

if H_path == L_path:
need_degradation = True
logger_name = result_name
utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
logger = logging.getLogger(logger_name)

need_H = True if H_path is not None else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ----------------------------------------
# load model
# ----------------------------------------

from models.network_dncnn import DnCNN as net
model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='R')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
logger.info('Model path: {:s}'.format(model_path))
number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
logger.info('Params number: {}'.format(number_parameters))

test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []

logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
logger.info(L_path)
L_paths = util.get_image_paths(L_path)
H_paths = util.get_image_paths(H_path) if need_H else None

for idx, img in enumerate(L_paths):

# ------------------------------------
# (1) img_L
# ------------------------------------

img_name, ext = os.path.splitext(os.path.basename(img))
# logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
img_L = util.imread_uint(img, n_channels=n_channels)
img_L = util.uint2single(img_L)

if need_degradation: # degradation process
np.random.seed(seed=0) # for reproducibility
img_L += np.random.normal(0, noise_level_img/255., img_L.shape)

util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None

img_L = util.single2tensor4(img_L)
img_L = img_L.to(device)

# ------------------------------------
# (2) img_E
# ------------------------------------

if not x8:
img_E = model(img_L)
else:
img_E = utils_model.test_mode(model, img_L, mode=3)

img_E = util.tensor2uint(img_E)

if need_H:

# --------------------------------
# (3) img_H
# --------------------------------

img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
img_H = img_H.squeeze()

# --------------------------------
# PSNR and SSIM
# --------------------------------

psnr = util.calculate_psnr(img_E, img_H, border=border)
ssim = util.calculate_ssim(img_E, img_H, border=border)
test_results['psnr'].append(psnr)
test_results['ssim'].append(ssim)
logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))
util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None

# ------------------------------------
# save results
# ------------------------------------

util.imsave(img_E, os.path.join(E_path, img_name+ext))

if need_H:
ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))

if __name__ == '__main__':

main()
Loading

0 comments on commit 1ece71e

Please sign in to comment.