diff --git a/README.md b/README.md new file mode 100644 index 0000000..335ac24 --- /dev/null +++ b/README.md @@ -0,0 +1,112 @@ +## Towards Flexible Blind JPEG Artifacts Removal (ICCV 2021) (PyTorch) + +Jiaxi Jiang, Kai Zhang, Radu Timofte + +Computer Vision Lab, ETH Zurich, Switzerland + +________ +✨ _**Some visual examples (click the images for more details)**_: + +[](https://imgsli.com/NzA3NjI) +[](https://imgsli.com/NzA3NTk) +[](https://imgsli.com/NzA3NjE) +[](https://imgsli.com/NzA3NjM) +[](https://imgsli.com/NzA3NjQ) +[](https://imgsli.com/NzA3NjU) + +________ + +### 1. Motivations +JPEG is one of the most widely-used image compression algorithms and formats due to its simplicity and fast encoding/decoding speeds. However, it is a lossy compression algorithm and can introduce annoying artifacts. Existing methods for JPEG artifacts removal generally have four limitations in real applications: + +a. Most existing learning-based methods [e.g. ARCNN, MWCNN, SwinIR] trained a specific model for each quality factor, lacking the flexibility to learn a single model for different JPEG quality factors. + +b. DCT-based methods [e.g. DMCNN, QGAC] need to obtain the DCT coefficients or quantization table as input, which is only stored in JPEG format. Besides, when images are compressed multiple times, only the most recent compression information is stored. + +c. Existing blind methods [e.g. DnCNN, DCSC, QGAC] can only provide a deterministic reconstruction result for each input, ignoring the need for user preferences. + +d. Existing methods are all trained with synthetic images which assumes that the low-quality images are compressed only once. However, most images from the Internet are compressed multiple times. Despite some progress for real recompressed images, e.g. from Twitter [ARCNN, DCSC], a detailed and complete study on double JPEG artifacts removal is still missing. + + +### 2. Network Architecture +We propose a flexible blind convolutional neural network (FBCNN) that predicts the quality factor of a JPEG image and embed it into the decoder to guide image restoration. The quality factor can be manually adjusted for flexible JPEG restoration according to the user's preference. +![architecture](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/architecture.png) + +### 3. Consideration on Restoration of Double JPEG Restoration +#### a. Limitation of Existing Blind Methods +We find that existing blind methods always do not work when the 8x8 blocks of two JPEG compression are not aligned and QF1 <= QF2, _**even with just a one-pixel shift.**_ Other cases such as non-aligned double JPEG with QF1>QF2, or aligned double JPEG compression, are actually equivalent to single JPEG compression. + +Here is an example of the restoration result of DnCNN and QGAC on a JPEG image with different degradation settings. '*' means there is a one-pixel shift between two JPEG blocks. +![lena_doublejpeg](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/lena_doublejpeg.png) + + +#### b. Our Solutions +We find for non-aligned double JPEG images with QF1 < QF2, FBCNN always predicts the quality factor as QF2. However, it is the smaller QF1 that dominants the compression artifacts. By manually changing the predicted quality factor to QF1, we largely improve the result. + +Besides, to get a fully blind model, we propose two blind solutions to solve this problem: + +(1) FBCNN-D: Train a model with a single JPEG degradation model + automatic dominant QF correction. By utilizing the property of JPEG images, we find the quality factor of a single JPEG image can be predicted by applying another JPEG compression. When QF1 = QF2, the MSE of two JPEG images is minimal. In our paper, we also extend this method to non-aligned double JPEG cases to get a fully blind model. + +(2) FBCNN-A: Augment training data with double JPEG degradation model, which is given by: + +
+ y = JPEG(shift(JPEG(x, QF1)),QF2) +
+ +By reducing the misalignment of training data and real-world JPEG images, FBCNN-A further improves the results on complex double JPEG restoration. _**This proposed double JPEG degradation model can be easily integrated into other image restoration tasks, such as Single Image Super-Resolution, for better general real image restoration.**_ + +To the best of our knowledge, we are the first to tackle the problem of restoration of non-aligned double JPEG compression. As JPEG is the most widely used image compression algorithm and format, and most real-world JPEG images are compressed many times, we believe it would be a significant step towards real image restoration. + +### 3. Experiments + +#### a. Single JPEG Restoration +![single_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_table.png) +*: Train a specific model for each quality factor. +![single_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_compare.png) + +#### b. Non-Aligned Double JPEG Restoration +There is a pixel shift of (4,4) between the blocks of two JPEG compression. +![double_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_table.png) +![double_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_compare.png) + +#### c. Real JPEG Restoration +![real](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/real.png) + +#### d. Flexibility of FBCNN +By setting different quality factors, we can control the trade-off between artifacts removal and details preservation. +![flexible](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/flexible.png) +### 4. Training +We will release the training code at [KAIR](https://github.com/cszn/KAIR/). + +### 5. Testing +#### a. Grayscale Images (Calculate Metrics) +Put the folder with uncompressed grayscale or Y channel images (Classic5, LIVE1, BSDS500, etc.) under `testsets`. This code generates compressed JPEG images and calculates PSNR, SSIM, PSNRB. + +###### Single JPEG Restoration +```bash +python main_test_fbcnn_gray.py +``` +###### Double JPEG Restoration +```bash +python main_test_fbcnn_gray_doublejpeg.py +``` + +#### b. Color Images (Calculate Metrics) +Put the folder with uncompressed images (LIVE1, BSDS500, ICB, etc.) under `testsets`. This code generates compressed JPEG images and calculates PSNR, SSIM, PSNRB. + +```bash +python main_test_fbcnn_color.py +``` +#### c. Real-World Images (Real Application) +Put the folder with real-world compressed images under `testsets`. This code restores the images without calculating metrics. Please note by setting different quality factors, we can control the trade-off between artifacts removal and details preservation. +```bash +python main_test_fbcnn_color_real.py +``` + +### Citation + @inproceedings{jiang2021towards, + title={Towards Flexible Blind {JPEG} Artifacts Removal}, + author={Jiang, Jiaxi and Zhang, Kai and Timofte, Radu}, + booktitle={IEEE International Conference on Computer Vision}, + year={2021} + } diff --git a/figs/architecture.png b/figs/architecture.png new file mode 100644 index 0000000..fd55eb3 Binary files /dev/null and b/figs/architecture.png differ diff --git a/figs/double_compare.png b/figs/double_compare.png new file mode 100644 index 0000000..d66d136 Binary files /dev/null and b/figs/double_compare.png differ diff --git a/figs/double_table.png b/figs/double_table.png new file mode 100644 index 0000000..ac3027d Binary files /dev/null and b/figs/double_table.png differ diff --git a/figs/flexible.png b/figs/flexible.png new file mode 100644 index 0000000..8584c29 Binary files /dev/null and b/figs/flexible.png differ diff --git a/figs/lena_doublejpeg.png b/figs/lena_doublejpeg.png new file mode 100644 index 0000000..72f9a07 Binary files /dev/null and b/figs/lena_doublejpeg.png differ diff --git a/figs/real.png b/figs/real.png new file mode 100644 index 0000000..e450ae1 Binary files /dev/null and b/figs/real.png differ diff --git a/figs/single_compare.png b/figs/single_compare.png new file mode 100644 index 0000000..b929d28 Binary files /dev/null and b/figs/single_compare.png differ diff --git a/figs/single_table.png b/figs/single_table.png new file mode 100644 index 0000000..359f11a Binary files /dev/null and b/figs/single_table.png differ diff --git a/figs/v1.png b/figs/v1.png new file mode 100644 index 0000000..2e9c911 Binary files /dev/null and b/figs/v1.png differ diff --git a/figs/v2.png b/figs/v2.png new file mode 100644 index 0000000..8679d6e Binary files /dev/null and b/figs/v2.png differ diff --git a/figs/v3.png b/figs/v3.png new file mode 100644 index 0000000..3d982cd Binary files /dev/null and b/figs/v3.png differ diff --git a/figs/v4.png b/figs/v4.png new file mode 100644 index 0000000..39b0c5d Binary files /dev/null and b/figs/v4.png differ diff --git a/figs/v5.png b/figs/v5.png new file mode 100644 index 0000000..7d6c1cf Binary files /dev/null and b/figs/v5.png differ diff --git a/figs/v6.png b/figs/v6.png new file mode 100644 index 0000000..cfdb764 Binary files /dev/null and b/figs/v6.png differ diff --git a/main_test_fbcnn_color.py b/main_test_fbcnn_color.py new file mode 100644 index 0000000..2c48171 --- /dev/null +++ b/main_test_fbcnn_color.py @@ -0,0 +1,112 @@ +import os.path +import logging +import numpy as np +from datetime import datetime +from collections import OrderedDict +import torch +import cv2 +from utils import utils_logger +from utils import utils_image as util + +def main(): + + quality_factor_list = [10, 20, 30, 40, 50, 60, 70, 80, 90] + testset_name = 'LIVE1_color' # 'LIVE1_color' 'BSDS500_color' 'ICB' + n_channels = 3 # set 1 for grayscale image, set 3 for color image + model_name = 'fbcnn_color.pth' + nc = [64,128,256,512] + nb = 4 + show_img = False # default: False + testsets = 'testsets' + results = 'test_results' + + for quality_factor in quality_factor_list: + + result_name = testset_name + '_' + model_name[:-4] + H_path = os.path.join(testsets, testset_name) + E_path = os.path.join(results, result_name, str(quality_factor)) # E_path, for Estimated images + util.mkdir(E_path) + + model_pool = 'model_zoo' # fixed + model_path = os.path.join(model_pool, model_name) + logger_name = result_name + '_qf_' + str(quality_factor) + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor)) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + border = 0 + + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_fbcnn import FBCNN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnrb'] = [] + + H_paths = util.get_image_paths(H_path) + for idx, img in enumerate(H_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + img_name, ext = os.path.splitext(os.path.basename(img)) + logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR) + _, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3) + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) + img_L = util.uint2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + #img_E,QF = model(img_L, torch.tensor([[0.6]])) + img_E,QF = model(img_L) + QF = 1 - QF + img_E = util.tensor2single(img_E) + img_E = util.single2uint(img_E) + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze() + # -------------------------------- + # PSNR and SSIM, PSNRB + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + psnrb = util.calculate_psnrb(img_H, img_E, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + test_results['psnrb'].append(psnrb) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.3f}; PSNRB: {:.2f} dB.'.format(img_name+ext, psnr, ssim, psnrb)) + logger.info('predicted quality factor: {:d}'.format(round(float(QF*100)))) + + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb']) + logger.info( + 'Average PSNR/SSIM/PSNRB - {} -: {:.2f}$\\vert${:.4f}$\\vert${:.2f}.'.format(result_name+'_'+str(quality_factor), ave_psnr, ave_ssim, ave_psnrb)) + + +if __name__ == '__main__': + main() diff --git a/main_test_fbcnn_color_real.py b/main_test_fbcnn_color_real.py new file mode 100644 index 0000000..6bb5567 --- /dev/null +++ b/main_test_fbcnn_color_real.py @@ -0,0 +1,95 @@ +import os.path +import logging +import numpy as np +from datetime import datetime +from collections import OrderedDict +import torch +import cv2 +from utils import utils_logger +from utils import utils_image as util + +def main(): + + testset_name = 'Real' # folder name of real images + n_channels = 3 # set 1 for grayscale image, set 3 for color image + model_name = 'fbcnn_color.pth' + nc = [64,128,256,512] + nb = 4 + testsets = 'testsets' + results = 'test_results' + + do_flexible_control = True + QF_control = [10,30,50,70,90] # adjust qf as input to provide different results + + result_name = testset_name + '_' + model_name[:-4] + L_path = os.path.join(testsets, testset_name) + E_path = os.path.join(results, result_name) # E_path, for Estimated images + util.mkdir(E_path) + + model_pool = 'model_zoo' # fixed + model_path = os.path.join(model_pool, model_name) + logger_name = result_name + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + border = 0 + + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_fbcnn import FBCNN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='BR') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnrb'] = [] + + L_paths = util.get_image_paths(L_path) + for idx, img in enumerate(L_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + img_name, ext = os.path.splitext(os.path.basename(img)) + logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + + img_L = util.uint2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + #img_E,QF = model(img_L, torch.tensor([[0.6]])) + img_E,QF = model(img_L) + QF = 1- QF + img_E = util.tensor2single(img_E) + img_E = util.single2uint(img_E) + logger.info('predicted quality factor: {:d}'.format(round(float(QF*100)))) + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + if do_flexible_control: + for QF_set in QF_control: + logger.info('Flexible control by QF = {:d}'.format(QF_set)) + # from IPython import embed; embed() + qf_input = torch.tensor([[1-QF_set/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-QF_set/100]]) + img_E,QF = model(img_L, qf_input) + QF = 1- QF + img_E = util.tensor2single(img_E) + img_E = util.single2uint(img_E) + util.imsave(img_E, os.path.join(E_path, img_name + '_qf_'+ str(QF_set)+'.png')) + + +if __name__ == '__main__': + main() diff --git a/main_test_fbcnn_gray.py b/main_test_fbcnn_gray.py new file mode 100644 index 0000000..647cc6b --- /dev/null +++ b/main_test_fbcnn_gray.py @@ -0,0 +1,112 @@ +import os.path +import logging +import numpy as np +from datetime import datetime +from collections import OrderedDict +import torch +import cv2 +from utils import utils_logger +from utils import utils_image as util + +def main(): + + quality_factor_list = [10, 20, 30, 40, 50, 60, 70, 80, 90] + testset_name = 'Classic5' # 'LIVE1_gray' 'Classic5' 'BSDS500_gray' + n_channels = 1 # set 1 for grayscale image, set 3 for color image + model_name = 'fbcnn_gray.pth' + nc = [64,128,256,512] + nb = 4 + show_img = False # default: False + testsets = 'testsets' + results = 'test_results' + + for quality_factor in quality_factor_list: + result_name = testset_name + '_' + model_name[:-4] + H_path = os.path.join(testsets, testset_name) + E_path = os.path.join(results, result_name, str(quality_factor)) # E_path, for Estimated images + util.mkdir(E_path) + + model_pool = 'model_zoo' # fixed + model_path = os.path.join(model_pool, model_name) + logger_name = result_name + '_qf_' + str(quality_factor) + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor)) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + border = 0 + + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_fbcnn import FBCNN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnrb'] = [] + + H_paths = util.get_image_paths(H_path) + for idx, img in enumerate(H_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + img_name, ext = os.path.splitext(os.path.basename(img)) + logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR) + _, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3) + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) + img_L = util.uint2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + # img_E,QF = model(img_L, torch.tensor([[0.6]])) + img_E,QF = model(img_L) + QF = 1 - QF + + img_E = util.tensor2single(img_E) + img_E = util.single2uint(img_E) + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze() + # -------------------------------- + # PSNR and SSIM, PSNRB + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + psnrb = util.calculate_psnrb(img_H, img_E, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + test_results['psnrb'].append(psnrb) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.3f}; PSNRB: {:.2f} dB.'.format(img_name+ext, psnr, ssim, psnrb)) + logger.info('predicted quality factor: {:d}'.format(round(float(QF*100)))) + + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb']) + logger.info( + 'Average PSNR/SSIM/PSNRB - {} -: {:.2f}$\\vert${:.4f}$\\vert${:.2f}.'.format(result_name+'_'+str(quality_factor), ave_psnr, ave_ssim, ave_psnrb)) + + +if __name__ == '__main__': + main() diff --git a/main_test_fbcnn_gray_doublejpeg.py b/main_test_fbcnn_gray_doublejpeg.py new file mode 100644 index 0000000..30f0548 --- /dev/null +++ b/main_test_fbcnn_gray_doublejpeg.py @@ -0,0 +1,118 @@ +import os.path +import logging +import numpy as np +from datetime import datetime +from collections import OrderedDict +import torch +import cv2 +from utils import utils_logger +from utils import utils_image as util + +def main(): + + quality_factor_list = [10, 30, 50] + testset_name = 'LIVE1_gray' # 'LIVE1_gray' 'Classic5' 'BSDS500_gray' + n_channels = 1 # set 1 for grayscale image, set 3 for color image + model_name = 'fbcnn_gray_double.pth' + nc = [64,128,256,512] + nb = 4 + show_img = False # default: False + testsets = 'testsets' + results = 'test_results' + + for qf1 in quality_factor_list: + for qf2 in quality_factor_list: + result_name = testset_name + '_' + model_name[:-4] + H_path = os.path.join(testsets, testset_name) + E_path = os.path.join(results, result_name, str(qf1)+str(qf2)) # E_path, for Estimated images + util.mkdir(E_path) + + model_pool = 'model_zoo' # fixed + model_path = os.path.join(model_pool, model_name) + logger_name = result_name + '_qf_' + str(qf1)+str(qf2) + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) + logger = logging.getLogger(logger_name) + logger.info('--------------- QF1={:d}, QF2={:d} ---------------'.format(qf1,qf2)) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + border = 0 + + + # ---------------------------------------- + # load model + # ---------------------------------------- + + from models.network_fbcnn import FBCNN as net + model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R') + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + logger.info('Model path: {:s}'.format(model_path)) + + test_results = OrderedDict() + test_results['psnr'] = [] + test_results['ssim'] = [] + test_results['psnrb'] = [] + + H_paths = util.get_image_paths(H_path) + for idx, img in enumerate(H_paths): + + # ------------------------------------ + # (1) img_L + # ------------------------------------ + img_name, ext = os.path.splitext(os.path.basename(img)) + logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) + img_L = util.imread_uint(img, n_channels=n_channels) + + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR) + _, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), qf1]) + img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3) + + shift_h, shift_w = 4, 4 + _, encimg = cv2.imencode('.jpg', img_L[shift_h:,shift_w:], [int(cv2.IMWRITE_JPEG_QUALITY), qf2]) + img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3) + + if n_channels == 3: + img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) + img_L = util.uint2tensor4(img_L) + img_L = img_L.to(device) + + # ------------------------------------ + # (2) img_E + # ------------------------------------ + + # img_E,QF = model(img_L, torch.tensor([[0.6]])) + img_E,QF = model(img_L) + QF = 1 - QF + + img_E = util.tensor2single(img_E) + img_E = util.single2uint(img_E) + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze()[shift_h:,shift_w:] + # -------------------------------- + # PSNR and SSIM, PSNRB + # -------------------------------- + + psnr = util.calculate_psnr(img_E, img_H, border=border) + ssim = util.calculate_ssim(img_E, img_H, border=border) + psnrb = util.calculate_psnrb(img_H, img_E, border=border) + test_results['psnr'].append(psnr) + test_results['ssim'].append(ssim) + test_results['psnrb'].append(psnrb) + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.3f}; PSNRB: {:.2f} dB.'.format(img_name+ext, psnr, ssim, psnrb)) + logger.info('predicted quality factor: {:d}'.format(round(float(QF*100)))) + + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None + util.imsave(img_E, os.path.join(E_path, img_name+'.png')) + + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) + ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb']) + logger.info( + 'Average PSNR/SSIM/PSNRB - {} -: {:.2f}$\\vert${:.4f}$\\vert${:.2f}.'.format(result_name+'_qf1_'+str(qf1)+'_qf2_'+str(qf2), ave_psnr, ave_ssim, ave_psnrb)) + + +if __name__ == '__main__': + main() diff --git a/model_zoo/README.md b/model_zoo/README.md new file mode 100644 index 0000000..900b707 --- /dev/null +++ b/model_zoo/README.md @@ -0,0 +1,5 @@ + +* Download the following models from [Google drive](https://drive.google.com/drive/folders/18k-_hcLUL8HuIOWb0OGc4ki959tvqhwr?usp=sharing) or [腾讯微云](https://share.weiyun.com/20CmUkqF). + * fbcnn_color.pth + * fbcnn_gray.pth + * fbcnn_gray_double.pth diff --git a/models/__pycache__/network_fbcnn.cpython-37.pyc b/models/__pycache__/network_fbcnn.cpython-37.pyc new file mode 100644 index 0000000..ae46c54 Binary files /dev/null and b/models/__pycache__/network_fbcnn.cpython-37.pyc differ diff --git a/models/network_fbcnn.py b/models/network_fbcnn.py new file mode 100644 index 0000000..4f62b47 --- /dev/null +++ b/models/network_fbcnn.py @@ -0,0 +1,337 @@ +from collections import OrderedDict +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import torchvision.models as models + +''' +# -------------------------------------------- +# Advanced nn.Sequential +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def sequential(*args): + """Advanced nn.Sequential. + + Args: + nn.Sequential, nn.Module + + Returns: + nn.Sequential + """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + +# -------------------------------------------- +# return nn.Sequantial of (Conv + BN + ReLU) +# -------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2): + L = [] + for t in mode: + if t == 'C': + L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) + elif t == 'T': + L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'v': + L.append(nn.Upsample(scale_factor=4, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + else: + raise NotImplementedError('Undefined type: '.format(t)) + return sequential(*L) + +# -------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# -------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2): + super(ResBlock, self).__init__() + + assert in_channels == out_channels, 'Only support in_channels==out_channels.' + if mode[0] in ['R', 'L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + + def forward(self, x): + res = self.res(x) + return x + res + +# -------------------------------------------- +# conv + subp (+ relu) +# -------------------------------------------- +def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope) + return up1 + + +# -------------------------------------------- +# nearest_upsample + conv (+ R) +# -------------------------------------------- +def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR' + if mode[0] == '2': + uc = 'UC' + elif mode[0] == '3': + uc = 'uC' + elif mode[0] == '4': + uc = 'vC' + mode = mode.replace(mode[0], uc) + up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope) + return up1 + + +# -------------------------------------------- +# convTranspose (+ relu) +# -------------------------------------------- +def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], 'T') + up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + return up1 + + +''' +# -------------------------------------------- +# Downsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# downsample_strideconv +# downsample_maxpool +# downsample_avgpool +# -------------------------------------------- +''' + + +# -------------------------------------------- +# strideconv (+ relu) +# -------------------------------------------- +def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], 'C') + down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + return down1 + + +# -------------------------------------------- +# maxpooling + conv (+ relu) +# -------------------------------------------- +def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], 'MC') + pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) + pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) + return sequential(pool, pool_tail) + + +# -------------------------------------------- +# averagepooling + conv (+ relu) +# -------------------------------------------- +def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], 'AC') + pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) + pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) + return sequential(pool, pool_tail) + + + +class QFAttention(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2): + super(QFAttention, self).__init__() + + assert in_channels == out_channels, 'Only support in_channels==out_channels.' + if mode[0] in ['R', 'L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + + def forward(self, x, gamma, beta): + gamma = gamma.unsqueeze(-1).unsqueeze(-1) + beta = beta.unsqueeze(-1).unsqueeze(-1) + res = (gamma)*self.res(x) + beta + return x + res + + +class FBCNN(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', + upsample_mode='convtranspose'): + super(FBCNN, self).__init__() + + self.m_head = conv(in_nc, nc[0], bias=True, mode='C') + self.nb = nb + self.nc = nc + # downsample + if downsample_mode == 'avgpool': + downsample_block = downsample_avgpool + elif downsample_mode == 'maxpool': + downsample_block = downsample_maxpool + elif downsample_mode == 'strideconv': + downsample_block = downsample_strideconv + else: + raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) + + self.m_down1 = sequential( + *[ResBlock(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)], + downsample_block(nc[0], nc[1], bias=True, mode='2')) + self.m_down2 = sequential( + *[ResBlock(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)], + downsample_block(nc[1], nc[2], bias=True, mode='2')) + self.m_down3 = sequential( + *[ResBlock(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)], + downsample_block(nc[2], nc[3], bias=True, mode='2')) + + self.m_body_encoder = sequential( + *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]) + + self.m_body_decoder = sequential( + *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]) + + # upsample + if upsample_mode == 'upconv': + upsample_block = upsample_upconv + elif upsample_mode == 'pixelshuffle': + upsample_block = upsample_pixelshuffle + elif upsample_mode == 'convtranspose': + upsample_block = upsample_convtranspose + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + + self.m_up3 = nn.ModuleList([upsample_block(nc[3], nc[2], bias=True, mode='2'), + *[QFAttention(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]]) + + self.m_up2 = nn.ModuleList([upsample_block(nc[2], nc[1], bias=True, mode='2'), + *[QFAttention(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]]) + + self.m_up1 = nn.ModuleList([upsample_block(nc[1], nc[0], bias=True, mode='2'), + *[QFAttention(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]]) + + + self.m_tail = conv(nc[0], out_nc, bias=True, mode='C') + + + self.qf_pred = sequential(*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)], + torch.nn.AdaptiveAvgPool2d((1,1)), + torch.nn.Flatten(), + torch.nn.Linear(512, 512), + nn.ReLU(), + torch.nn.Linear(512, 512), + nn.ReLU(), + torch.nn.Linear(512, 1), + nn.Sigmoid() + ) + + self.qf_embed = sequential(torch.nn.Linear(1, 512), + nn.ReLU(), + torch.nn.Linear(512, 512), + nn.ReLU(), + torch.nn.Linear(512, 512), + nn.ReLU() + ) + + self.to_gamma_3 = sequential(torch.nn.Linear(512, nc[2]),nn.Sigmoid()) + self.to_beta_3 = sequential(torch.nn.Linear(512, nc[2]),nn.Tanh()) + self.to_gamma_2 = sequential(torch.nn.Linear(512, nc[1]),nn.Sigmoid()) + self.to_beta_2 = sequential(torch.nn.Linear(512, nc[1]),nn.Tanh()) + self.to_gamma_1 = sequential(torch.nn.Linear(512, nc[0]),nn.Sigmoid()) + self.to_beta_1 = sequential(torch.nn.Linear(512, nc[0]),nn.Tanh()) + + + def forward(self, x, qf_input=None): + + h, w = x.size()[-2:] + paddingBottom = int(np.ceil(h / 8) * 8 - h) + paddingRight = int(np.ceil(w / 8) * 8 - w) + x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x) + + x1 = self.m_head(x) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body_encoder(x4) + qf = self.qf_pred(x) + x = self.m_body_decoder(x) + qf_embedding = self.qf_embed(qf_input) if qf_input is not None else self.qf_embed(qf) + gamma_3 = self.to_gamma_3(qf_embedding) + beta_3 = self.to_beta_3(qf_embedding) + + gamma_2 = self.to_gamma_2(qf_embedding) + beta_2 = self.to_beta_2(qf_embedding) + + gamma_1 = self.to_gamma_1(qf_embedding) + beta_1 = self.to_beta_1(qf_embedding) + + + x = x + x4 + x = self.m_up3[0](x) + for i in range(self.nb): + x = self.m_up3[i+1](x, gamma_3,beta_3) + + x = x + x3 + + x = self.m_up2[0](x) + for i in range(self.nb): + x = self.m_up2[i+1](x, gamma_2, beta_2) + x = x + x2 + + x = self.m_up1[0](x) + for i in range(self.nb): + x = self.m_up1[i+1](x, gamma_1, beta_1) + + x = x + x1 + x = self.m_tail(x) + x = x[..., :h, :w] + + return x, qf + +if __name__ == "__main__": + x = torch.randn(1, 3, 96, 96)#.cuda()#.to(torch.device('cuda')) + fbar=FBAR() + y,qf = fbar(x) + print(y.shape,qf.shape) diff --git a/test_results/LIVE1_color_fbcnn_color/10/LIVE1_color_fbcnn_color_qf_10.log b/test_results/LIVE1_color_fbcnn_color/10/LIVE1_color_fbcnn_color_qf_10.log new file mode 100644 index 0000000..fcf41c9 --- /dev/null +++ b/test_results/LIVE1_color_fbcnn_color/10/LIVE1_color_fbcnn_color_qf_10.log @@ -0,0 +1,21 @@ +21-09-13 03:44:52.951 : --------------- quality factor: 10 --------------- +21-09-13 03:45:12.512 : Model path: model_zoo/fbcnn_color.pth +21-09-13 03:45:12.518 : ---1--> bikes.bmp +21-09-13 03:45:52.038 : --------------- quality factor: 10 --------------- +21-09-13 03:45:53.114 : Model path: model_zoo/fbcnn_color.pth +21-09-13 03:45:53.115 : ---1--> bikes.bmp +21-09-13 03:46:08.818 : bikes.bmp - PSNR: 26.39 dB; SSIM: 0.802; PSNRB: 26.12 dB. +21-09-13 03:46:08.822 : predicted quality factor: 10 +21-09-13 03:46:08.862 : ---2--> building2.bmp +21-09-13 03:46:22.249 : building2.bmp - PSNR: 23.38 dB; SSIM: 0.759; PSNRB: 22.88 dB. +21-09-13 03:46:22.249 : predicted quality factor: 10 +21-09-13 03:46:22.281 : ---3--> buildings.bmp +21-09-13 03:46:39.070 : buildings.bmp - PSNR: 25.99 dB; SSIM: 0.814; PSNRB: 25.51 dB. +21-09-13 03:46:39.070 : predicted quality factor: 10 +21-09-13 03:46:39.111 : ---4--> caps.bmp +21-09-13 03:46:56.712 : caps.bmp - PSNR: 31.66 dB; SSIM: 0.867; PSNRB: 31.62 dB. +21-09-13 03:46:56.712 : predicted quality factor: 10 +21-09-13 03:46:56.750 : ---5--> carnivaldolls.bmp +21-09-13 03:47:10.286 : carnivaldolls.bmp - PSNR: 27.85 dB; SSIM: 0.853; PSNRB: 27.85 dB. +21-09-13 03:47:10.286 : predicted quality factor: 10 +21-09-13 03:47:10.318 : ---6--> cemetry.bmp diff --git a/test_results/LIVE1_color_fbcnn_color/10/bikes.png b/test_results/LIVE1_color_fbcnn_color/10/bikes.png new file mode 100644 index 0000000..da78448 Binary files /dev/null and b/test_results/LIVE1_color_fbcnn_color/10/bikes.png differ diff --git a/test_results/LIVE1_color_fbcnn_color/10/building2.png b/test_results/LIVE1_color_fbcnn_color/10/building2.png new file mode 100644 index 0000000..8d450a5 Binary files /dev/null and b/test_results/LIVE1_color_fbcnn_color/10/building2.png differ diff --git a/test_results/LIVE1_color_fbcnn_color/10/buildings.png b/test_results/LIVE1_color_fbcnn_color/10/buildings.png new file mode 100644 index 0000000..53d29dc Binary files /dev/null and b/test_results/LIVE1_color_fbcnn_color/10/buildings.png differ diff --git a/test_results/LIVE1_color_fbcnn_color/10/caps.png b/test_results/LIVE1_color_fbcnn_color/10/caps.png new file mode 100644 index 0000000..cbcd33a Binary files /dev/null and b/test_results/LIVE1_color_fbcnn_color/10/caps.png differ diff --git a/test_results/LIVE1_color_fbcnn_color/10/carnivaldolls.png b/test_results/LIVE1_color_fbcnn_color/10/carnivaldolls.png new file mode 100644 index 0000000..1a605a8 Binary files /dev/null and b/test_results/LIVE1_color_fbcnn_color/10/carnivaldolls.png differ diff --git a/testsets/Classic5/1.bmp b/testsets/Classic5/1.bmp new file mode 100644 index 0000000..597bffe Binary files /dev/null and b/testsets/Classic5/1.bmp differ diff --git a/testsets/Classic5/2.bmp b/testsets/Classic5/2.bmp new file mode 100644 index 0000000..f1872c8 Binary files /dev/null and b/testsets/Classic5/2.bmp differ diff --git a/testsets/Classic5/3.bmp b/testsets/Classic5/3.bmp new file mode 100644 index 0000000..84f43e8 Binary files /dev/null and b/testsets/Classic5/3.bmp differ diff --git a/testsets/Classic5/4.bmp b/testsets/Classic5/4.bmp new file mode 100644 index 0000000..89496fe Binary files /dev/null and b/testsets/Classic5/4.bmp differ diff --git a/testsets/Classic5/5.bmp b/testsets/Classic5/5.bmp new file mode 100644 index 0000000..dd05a1b Binary files /dev/null and b/testsets/Classic5/5.bmp differ diff --git a/testsets/LIVE1_color/bikes.bmp b/testsets/LIVE1_color/bikes.bmp new file mode 100644 index 0000000..cc400a2 Binary files /dev/null and b/testsets/LIVE1_color/bikes.bmp differ diff --git a/testsets/LIVE1_color/building2.bmp b/testsets/LIVE1_color/building2.bmp new file mode 100644 index 0000000..f7258f8 Binary files /dev/null and b/testsets/LIVE1_color/building2.bmp differ diff --git a/testsets/LIVE1_color/buildings.bmp b/testsets/LIVE1_color/buildings.bmp new file mode 100644 index 0000000..77ad07c Binary files /dev/null and b/testsets/LIVE1_color/buildings.bmp differ diff --git a/testsets/LIVE1_color/caps.bmp b/testsets/LIVE1_color/caps.bmp new file mode 100644 index 0000000..da57d41 Binary files /dev/null and b/testsets/LIVE1_color/caps.bmp differ diff --git a/testsets/LIVE1_color/carnivaldolls.bmp b/testsets/LIVE1_color/carnivaldolls.bmp new file mode 100644 index 0000000..5e8ec42 Binary files /dev/null and b/testsets/LIVE1_color/carnivaldolls.bmp differ diff --git a/testsets/LIVE1_color/cemetry.bmp b/testsets/LIVE1_color/cemetry.bmp new file mode 100644 index 0000000..fb50090 Binary files /dev/null and b/testsets/LIVE1_color/cemetry.bmp differ diff --git a/testsets/LIVE1_color/churchandcapitol.bmp b/testsets/LIVE1_color/churchandcapitol.bmp new file mode 100644 index 0000000..e917d69 Binary files /dev/null and b/testsets/LIVE1_color/churchandcapitol.bmp differ diff --git a/testsets/LIVE1_color/coinsinfountain.bmp b/testsets/LIVE1_color/coinsinfountain.bmp new file mode 100644 index 0000000..f482091 Binary files /dev/null and b/testsets/LIVE1_color/coinsinfountain.bmp differ diff --git a/testsets/LIVE1_color/dancers.bmp b/testsets/LIVE1_color/dancers.bmp new file mode 100644 index 0000000..3709b3c Binary files /dev/null and b/testsets/LIVE1_color/dancers.bmp differ diff --git a/testsets/LIVE1_color/flowersonih35.bmp b/testsets/LIVE1_color/flowersonih35.bmp new file mode 100644 index 0000000..d37a203 Binary files /dev/null and b/testsets/LIVE1_color/flowersonih35.bmp differ diff --git a/testsets/LIVE1_color/house.bmp b/testsets/LIVE1_color/house.bmp new file mode 100644 index 0000000..c5d558e Binary files /dev/null and b/testsets/LIVE1_color/house.bmp differ diff --git a/testsets/LIVE1_color/lighthouse2.bmp b/testsets/LIVE1_color/lighthouse2.bmp new file mode 100644 index 0000000..a0c9ed0 Binary files /dev/null and b/testsets/LIVE1_color/lighthouse2.bmp differ diff --git a/testsets/LIVE1_color/lighthouse3.bmp b/testsets/LIVE1_color/lighthouse3.bmp new file mode 100644 index 0000000..de759b9 Binary files /dev/null and b/testsets/LIVE1_color/lighthouse3.bmp differ diff --git a/testsets/LIVE1_color/manfishing.bmp b/testsets/LIVE1_color/manfishing.bmp new file mode 100644 index 0000000..2da705d Binary files /dev/null and b/testsets/LIVE1_color/manfishing.bmp differ diff --git a/testsets/LIVE1_color/monarch.bmp b/testsets/LIVE1_color/monarch.bmp new file mode 100644 index 0000000..85b01ef Binary files /dev/null and b/testsets/LIVE1_color/monarch.bmp differ diff --git a/testsets/LIVE1_color/ocean.bmp b/testsets/LIVE1_color/ocean.bmp new file mode 100644 index 0000000..81ee901 Binary files /dev/null and b/testsets/LIVE1_color/ocean.bmp differ diff --git a/testsets/LIVE1_color/paintedhouse.bmp b/testsets/LIVE1_color/paintedhouse.bmp new file mode 100644 index 0000000..80a3d00 Binary files /dev/null and b/testsets/LIVE1_color/paintedhouse.bmp differ diff --git a/testsets/LIVE1_color/parrots.bmp b/testsets/LIVE1_color/parrots.bmp new file mode 100644 index 0000000..015a7e2 Binary files /dev/null and b/testsets/LIVE1_color/parrots.bmp differ diff --git a/testsets/LIVE1_color/plane.bmp b/testsets/LIVE1_color/plane.bmp new file mode 100644 index 0000000..3281b88 Binary files /dev/null and b/testsets/LIVE1_color/plane.bmp differ diff --git a/testsets/LIVE1_color/rapids.bmp b/testsets/LIVE1_color/rapids.bmp new file mode 100644 index 0000000..d62bf0b Binary files /dev/null and b/testsets/LIVE1_color/rapids.bmp differ diff --git a/testsets/LIVE1_color/sailing1.bmp b/testsets/LIVE1_color/sailing1.bmp new file mode 100644 index 0000000..19290dd Binary files /dev/null and b/testsets/LIVE1_color/sailing1.bmp differ diff --git a/testsets/LIVE1_color/sailing2.bmp b/testsets/LIVE1_color/sailing2.bmp new file mode 100644 index 0000000..00862dd Binary files /dev/null and b/testsets/LIVE1_color/sailing2.bmp differ diff --git a/testsets/LIVE1_color/sailing3.bmp b/testsets/LIVE1_color/sailing3.bmp new file mode 100644 index 0000000..bec8f69 Binary files /dev/null and b/testsets/LIVE1_color/sailing3.bmp differ diff --git a/testsets/LIVE1_color/sailing4.bmp b/testsets/LIVE1_color/sailing4.bmp new file mode 100644 index 0000000..523a9b0 Binary files /dev/null and b/testsets/LIVE1_color/sailing4.bmp differ diff --git a/testsets/LIVE1_color/statue.bmp b/testsets/LIVE1_color/statue.bmp new file mode 100644 index 0000000..80dd972 Binary files /dev/null and b/testsets/LIVE1_color/statue.bmp differ diff --git a/testsets/LIVE1_color/stream.bmp b/testsets/LIVE1_color/stream.bmp new file mode 100644 index 0000000..5bb5c41 Binary files /dev/null and b/testsets/LIVE1_color/stream.bmp differ diff --git a/testsets/LIVE1_color/studentsculpture.bmp b/testsets/LIVE1_color/studentsculpture.bmp new file mode 100644 index 0000000..52e58dd Binary files /dev/null and b/testsets/LIVE1_color/studentsculpture.bmp differ diff --git a/testsets/LIVE1_color/woman.bmp b/testsets/LIVE1_color/woman.bmp new file mode 100644 index 0000000..8f94b9c Binary files /dev/null and b/testsets/LIVE1_color/woman.bmp differ diff --git a/testsets/LIVE1_color/womanhat.bmp b/testsets/LIVE1_color/womanhat.bmp new file mode 100644 index 0000000..5ba7e8d Binary files /dev/null and b/testsets/LIVE1_color/womanhat.bmp differ diff --git a/testsets/LIVE1_gray/bikes.png b/testsets/LIVE1_gray/bikes.png new file mode 100644 index 0000000..3f4b118 Binary files /dev/null and b/testsets/LIVE1_gray/bikes.png differ diff --git a/testsets/LIVE1_gray/building2.png b/testsets/LIVE1_gray/building2.png new file mode 100644 index 0000000..4a032bc Binary files /dev/null and b/testsets/LIVE1_gray/building2.png differ diff --git a/testsets/LIVE1_gray/buildings.png b/testsets/LIVE1_gray/buildings.png new file mode 100644 index 0000000..89d2770 Binary files /dev/null and b/testsets/LIVE1_gray/buildings.png differ diff --git a/testsets/LIVE1_gray/caps.png b/testsets/LIVE1_gray/caps.png new file mode 100644 index 0000000..7e2ba0c Binary files /dev/null and b/testsets/LIVE1_gray/caps.png differ diff --git a/testsets/LIVE1_gray/carnivaldolls.png b/testsets/LIVE1_gray/carnivaldolls.png new file mode 100644 index 0000000..72783b3 Binary files /dev/null and b/testsets/LIVE1_gray/carnivaldolls.png differ diff --git a/testsets/LIVE1_gray/cemetry.png b/testsets/LIVE1_gray/cemetry.png new file mode 100644 index 0000000..47b2412 Binary files /dev/null and b/testsets/LIVE1_gray/cemetry.png differ diff --git a/testsets/LIVE1_gray/churchandcapitol.png b/testsets/LIVE1_gray/churchandcapitol.png new file mode 100644 index 0000000..a9b358f Binary files /dev/null and b/testsets/LIVE1_gray/churchandcapitol.png differ diff --git a/testsets/LIVE1_gray/coinsinfountain.png b/testsets/LIVE1_gray/coinsinfountain.png new file mode 100644 index 0000000..4cc4a3a Binary files /dev/null and b/testsets/LIVE1_gray/coinsinfountain.png differ diff --git a/testsets/LIVE1_gray/dancers.png b/testsets/LIVE1_gray/dancers.png new file mode 100644 index 0000000..77dabc4 Binary files /dev/null and b/testsets/LIVE1_gray/dancers.png differ diff --git a/testsets/LIVE1_gray/flowersonih35.png b/testsets/LIVE1_gray/flowersonih35.png new file mode 100644 index 0000000..c1e22e0 Binary files /dev/null and b/testsets/LIVE1_gray/flowersonih35.png differ diff --git a/testsets/LIVE1_gray/house.png b/testsets/LIVE1_gray/house.png new file mode 100644 index 0000000..da169d0 Binary files /dev/null and b/testsets/LIVE1_gray/house.png differ diff --git a/testsets/LIVE1_gray/lighthouse2.png b/testsets/LIVE1_gray/lighthouse2.png new file mode 100644 index 0000000..144731e Binary files /dev/null and b/testsets/LIVE1_gray/lighthouse2.png differ diff --git a/testsets/LIVE1_gray/lighthouse3.png b/testsets/LIVE1_gray/lighthouse3.png new file mode 100644 index 0000000..d7cfde1 Binary files /dev/null and b/testsets/LIVE1_gray/lighthouse3.png differ diff --git a/testsets/LIVE1_gray/manfishing.png b/testsets/LIVE1_gray/manfishing.png new file mode 100644 index 0000000..863bd1f Binary files /dev/null and b/testsets/LIVE1_gray/manfishing.png differ diff --git a/testsets/LIVE1_gray/monarch.png b/testsets/LIVE1_gray/monarch.png new file mode 100644 index 0000000..82f6cc0 Binary files /dev/null and b/testsets/LIVE1_gray/monarch.png differ diff --git a/testsets/LIVE1_gray/ocean.png b/testsets/LIVE1_gray/ocean.png new file mode 100644 index 0000000..7f2763e Binary files /dev/null and b/testsets/LIVE1_gray/ocean.png differ diff --git a/testsets/LIVE1_gray/paintedhouse.png b/testsets/LIVE1_gray/paintedhouse.png new file mode 100644 index 0000000..e842ddf Binary files /dev/null and b/testsets/LIVE1_gray/paintedhouse.png differ diff --git a/testsets/LIVE1_gray/parrots.png b/testsets/LIVE1_gray/parrots.png new file mode 100644 index 0000000..6f02bc5 Binary files /dev/null and b/testsets/LIVE1_gray/parrots.png differ diff --git a/testsets/LIVE1_gray/plane.png b/testsets/LIVE1_gray/plane.png new file mode 100644 index 0000000..0eb43a3 Binary files /dev/null and b/testsets/LIVE1_gray/plane.png differ diff --git a/testsets/LIVE1_gray/rapids.png b/testsets/LIVE1_gray/rapids.png new file mode 100644 index 0000000..f4145ea Binary files /dev/null and b/testsets/LIVE1_gray/rapids.png differ diff --git a/testsets/LIVE1_gray/sailing1.png b/testsets/LIVE1_gray/sailing1.png new file mode 100644 index 0000000..83804c5 Binary files /dev/null and b/testsets/LIVE1_gray/sailing1.png differ diff --git a/testsets/LIVE1_gray/sailing2.png b/testsets/LIVE1_gray/sailing2.png new file mode 100644 index 0000000..1deb4f8 Binary files /dev/null and b/testsets/LIVE1_gray/sailing2.png differ diff --git a/testsets/LIVE1_gray/sailing3.png b/testsets/LIVE1_gray/sailing3.png new file mode 100644 index 0000000..d1d8240 Binary files /dev/null and b/testsets/LIVE1_gray/sailing3.png differ diff --git a/testsets/LIVE1_gray/sailing4.png b/testsets/LIVE1_gray/sailing4.png new file mode 100644 index 0000000..add91bb Binary files /dev/null and b/testsets/LIVE1_gray/sailing4.png differ diff --git a/testsets/LIVE1_gray/statue.png b/testsets/LIVE1_gray/statue.png new file mode 100644 index 0000000..49e6c26 Binary files /dev/null and b/testsets/LIVE1_gray/statue.png differ diff --git a/testsets/LIVE1_gray/stream.png b/testsets/LIVE1_gray/stream.png new file mode 100644 index 0000000..6f6771e Binary files /dev/null and b/testsets/LIVE1_gray/stream.png differ diff --git a/testsets/LIVE1_gray/studentsculpture.png b/testsets/LIVE1_gray/studentsculpture.png new file mode 100644 index 0000000..e3748dd Binary files /dev/null and b/testsets/LIVE1_gray/studentsculpture.png differ diff --git a/testsets/LIVE1_gray/woman.png b/testsets/LIVE1_gray/woman.png new file mode 100644 index 0000000..5dcc893 Binary files /dev/null and b/testsets/LIVE1_gray/woman.png differ diff --git a/testsets/LIVE1_gray/womanhat.png b/testsets/LIVE1_gray/womanhat.png new file mode 100644 index 0000000..545e74f Binary files /dev/null and b/testsets/LIVE1_gray/womanhat.png differ diff --git a/utils/__pycache__/utils_image.cpython-37.pyc b/utils/__pycache__/utils_image.cpython-37.pyc new file mode 100644 index 0000000..2012c01 Binary files /dev/null and b/utils/__pycache__/utils_image.cpython-37.pyc differ diff --git a/utils/__pycache__/utils_logger.cpython-37.pyc b/utils/__pycache__/utils_logger.cpython-37.pyc new file mode 100644 index 0000000..e3d0637 Binary files /dev/null and b/utils/__pycache__/utils_logger.cpython-37.pyc differ diff --git a/utils/utils_image.py b/utils/utils_image.py new file mode 100644 index 0000000..8fe0d5b --- /dev/null +++ b/utils/utils_image.py @@ -0,0 +1,999 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +# import torchvision.transforms as transforms +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) + # print(w1) + # print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# BEF: Blocking effect factor +# -------------------------------------------- +def compute_bef(img): + + block = 8 + height, width = img.shape[:2] + + H = [i for i in range(width-1)] + H_B = [i for i in range(block-1,width-1,block)] + H_BC = list(set(H)-set(H_B)) + + V = [i for i in range(height-1)] + V_B = [i for i in range(block-1,height-1,block)] + V_BC = list(set(V)-set(V_B)) + + D_B = 0 + D_BC = 0 + + for i in H_B: + diff = img[:,i] - img[:,i+1] + D_B += np.sum(diff**2) + + for i in H_BC: + diff = img[:,i] - img[:,i+1] + D_BC += np.sum(diff**2) + + + for j in V_B: + diff = img[j,:] - img[j+1,:] + D_B += np.sum(diff**2) + + for j in V_BC: + diff = img[j,:] - img[j+1,:] + D_BC += np.sum(diff**2) + + + N_HB = height * (width/block - 1) + N_HBC = height * (width - 1) - N_HB + N_VB = width * (height/block -1) + N_VBC = width * (height -1) - N_VB + D_B = D_B / (N_HB + N_VB) + D_BC = D_BC / (N_HBC + N_VBC) + eta = math.log2(block) / math.log2(min(height, width)) if D_B > D_BC else 0 + return eta * (D_B - D_BC) + + + +# -------------------------------------------- +# PSNRB +# -------------------------------------------- +def calculate_psnrb(img1, img2, border=0): + # img1: ground truth + # img2: compressed image + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + img1 = img1.astype(np.float64) + if img2.shape[-1]==3: + img2_y = rgb2ycbcr(img2).astype(np.float64) + bef = compute_bef(img2_y) + else: + img2 = img2.astype(np.float64) + bef = compute_bef(img2) + mse = np.mean((img1 - img2)**2) + mse_b = mse + bef + if mse_b == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse_b)) + + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) +# imshow(single2uint(img_bicubic)) +# +# img_tensor = single2tensor4(img) +# for i in range(8): +# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1)) + +# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200) +# imssave(patches,'a.png') \ No newline at end of file diff --git a/utils/utils_logger.py b/utils/utils_logger.py new file mode 100644 index 0000000..3067190 --- /dev/null +++ b/utils/utils_logger.py @@ -0,0 +1,66 @@ +import sys +import datetime +import logging + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def log(*args, **kwargs): + print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) + + +''' +# -------------------------------------------- +# logger +# -------------------------------------------- +''' + + +def logger_info(logger_name, log_path='default_logger.log'): + ''' set up logger + modified by Kai Zhang (github: https://github.com/cszn) + ''' + log = logging.getLogger(logger_name) + if log.hasHandlers(): + print('LogHandlers exist!') + else: + print('LogHandlers setup!') + level = logging.INFO + formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') + fh = logging.FileHandler(log_path, mode='a') + fh.setFormatter(formatter) + log.setLevel(level) + log.addHandler(fh) + # print(len(log.handlers)) + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + log.addHandler(sh) + + +''' +# -------------------------------------------- +# print to file and std_out simultaneously +# -------------------------------------------- +''' + + +class logger_print(object): + def __init__(self, log_path="default.log"): + self.terminal = sys.stdout + self.log = open(log_path, 'a') + + def write(self, message): + self.terminal.write(message) + self.log.write(message) # write the message + + def flush(self): + pass diff --git a/utils/utils_model.py b/utils/utils_model.py new file mode 100644 index 0000000..4aa678f --- /dev/null +++ b/utils/utils_model.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +from utils import utils_image as util +import re +import glob +import os +from IPython import embed + +''' +# -------------------------------------------- +# Model +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +''' + + +def find_last_checkpoint(save_dir, net_type='G'): + """ + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + save_dir: model folder + net_type: 'G' or 'D' + + Return: + init_iter: iteration number + init_path: model path + # --------------------------------------- + """ + file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) + if file_list: + iter_exist = [] + for file_ in file_list: + iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) + iter_exist.append(int(iter_current[0])) + init_iter = max(iter_exist) + init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) + else: + init_iter = 0 + init_path = None + return init_iter, init_path + + +def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): + ''' + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + model: trained model + L: input Low-quality image + mode: + (0) normal: test(model, L) + (1) pad: test_pad(model, L, modulo=16) + (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) + (3) x8: test_x8(model, L, modulo=1) ^_^ + (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) + refield: effective receptive filed of the network, 32 is enough + useful when split, i.e., mode=2, 4 + min_size: min_sizeXmin_size image, e.g., 256X256 image + useful when split, i.e., mode=2, 4 + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + useful when pad, i.e., mode=1 + + Returns: + E: estimated image + # --------------------------------------- + ''' + if mode == 0: + E = test(model, L) + elif mode == 1: + E = test_pad(model, L, modulo, sf) + elif mode == 2: + E = test_split(model, L, refield, min_size, sf, modulo) + elif mode == 3: + E = test_x8(model, L, modulo, sf) + elif mode == 4: + E = test_split_x8(model, L, refield, min_size, sf, modulo) + return E + + +''' +# -------------------------------------------- +# normal (0) +# -------------------------------------------- +''' + + +def test(model, L): + E = model(L) + return E + + +''' +# -------------------------------------------- +# pad (1) +# -------------------------------------------- +''' + + +def test_pad(model, L, modulo=16, sf=1): + h, w = L.size()[-2:] + paddingBottom = int(np.ceil(h/modulo)*modulo-h) + paddingRight = int(np.ceil(w/modulo)*modulo-w) + L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) + E,QF = model(L) + E = E[..., :h*sf, :w*sf] + return E + +def test_pad_deblocking(model, L, modulo=16): +# embed() + h, w = L.size()[-2:] + E0 = model(L) + paddingH = int(h-np.floor(h/modulo)*modulo) + paddingW = int(w-np.floor(w/modulo)*modulo) +# embed() + top = slice(0,h-paddingH) + top_c = slice(h-paddingH,h) + bottom = slice(paddingH,h) + bottom_c = slice(0,paddingH) + left = slice(0,w-paddingW) + left_c = slice(w-paddingW, w) + right = slice(paddingW,w) + right_c = slice(0,paddingW) + L1 = L[...,top,left] +# L2 = L[...,top,right] +# embed() +# L3 = L[...,bottom,left] +# L4 = L[...,bottom,right] + E1 = test_mode(model, L1, mode=3) +# E2 = test_mode(model, L2, mode=3) + E0[...,top,left] = E1 +# embed() +# E0[...,top,right] = E2 +# E0[...,top,left_c] = E2[...,top,-1*paddingW:] +# L1 = torch.nn.ZeroPad2d((0, paddingW, 0, paddingH))(L) +# L1 = torch.nn.ConstantPad2d((0, paddingW, 0, paddingH),0)(L) +# E1 = model(L1)[..., :h, :w] +# L2 = torch.nn.ZeroPad2d((paddingW,0 , 0, paddingH))(L) +# E2 = model(L2)[..., :h, paddingW:] +# L3 = torch.nn.ZeroPad2d((0, paddingW, paddingH, 0))(L) +# E3 = model(L3)[..., paddingH:, :w] +# L4 = torch.nn.ZeroPad2d((paddingW,0 , paddingH,0))(L) +# E4 = model(L4)[..., paddingH:, paddingW:] +# embed() + E_list = [E0] + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + + +''' +# -------------------------------------------- +# split (function) +# -------------------------------------------- +''' + + +def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): + """ + Args: + model: trained model + L: input Low-quality image + refield: effective receptive filed of the network, 32 is enough + min_size: min_sizeXmin_size image, e.g., 256X256 image + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + + Returns: + E: estimated result + """ + h, w = L.size()[-2:] + if h*w <= min_size**2: + L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) + E = model(L) + E = E[..., :h*sf, :w*sf] + else: + top = slice(0, (h//2//refield+1)*refield) + bottom = slice(h - (h//2//refield+1)*refield, h) + left = slice(0, (w//2//refield+1)*refield) + right = slice(w - (w//2//refield+1)*refield, w) + Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] + + if h * w <= 4*(min_size**2): + Es = [model(Ls[i]) for i in range(4)] + else: + Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] + + b, c = Es[0].size()[:2] + E = torch.zeros(b, c, sf * h, sf * w).type_as(L) + + E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] + E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] + E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] + E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] + return E + + +''' +# -------------------------------------------- +# split (2) +# -------------------------------------------- +''' + + +def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): + E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) + return E + + +''' +# -------------------------------------------- +# x8 (3) +# -------------------------------------------- +''' + + +def test_x8(model, L, modulo=1, sf=1): + E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)] + for i in range(len(E_list)): + if i == 3 or i == 5: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i) + else: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# -------------------------------------------- +# split and x8 (4) +# -------------------------------------------- +''' + + +def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): + E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] + for k, i in enumerate(range(len(E_list))): + if i==3 or i==5: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i) + else: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +''' + + +''' +# -------------------------------------------- +# print +# -------------------------------------------- +''' + + +# -------------------------------------------- +# print model +# -------------------------------------------- +def print_model(model): + msg = describe_model(model) + print(msg) + + +# -------------------------------------------- +# print params +# -------------------------------------------- +def print_params(model): + msg = describe_params(model) + print(msg) + + +''' +# -------------------------------------------- +# information +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model inforation +# -------------------------------------------- +def info_model(model): + msg = describe_model(model) + return msg + + +# -------------------------------------------- +# params inforation +# -------------------------------------------- +def info_params(model): + msg = describe_params(model) + return msg + + +''' +# -------------------------------------------- +# description +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model name and total number of parameters +# -------------------------------------------- +def describe_model(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += 'models name: {}'.format(model.__class__.__name__) + '\n' + msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n' + msg += 'Net structure:\n{}'.format(str(model)) + '\n' + return msg + + +# -------------------------------------------- +# parameters description +# -------------------------------------------- +def describe_params(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'param_name') + '\n' + for name, param in model.state_dict().items(): + if not 'num_batches_tracked' in name: + v = param.data.clone().float() + msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), name) + '\n' + return msg + + +if __name__ == '__main__': + + class Net(torch.nn.Module): + def __init__(self, in_channels=3, out_channels=3): + super(Net, self).__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + model = Net() + model = model.eval() + print_model(model) + print_params(model) + x = torch.randn((2,3,401,401)) + torch.cuda.empty_cache() + with torch.no_grad(): + for mode in range(5): + y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1) + print(y.shape) + + # run utils/utils_model.py \ No newline at end of file