## Towards Flexible Blind JPEG Artifacts Removal (ICCV 2021) (PyTorch)

Jiaxi Jiang, Kai Zhang, Radu Timofte

Computer Vision Lab, ETH Zurich, Switzerland

________

✨ _**Some visual examples (click the images for more details)**_:

[<img src="figs/v3.png" width="400px"/>](https://imgsli.com/NzA3NjI)
[<img src="figs/v1.png" width="400px"/>](https://imgsli.com/NzA3NTk)
[<img src="figs/v2.png" width="400px"/>](https://imgsli.com/NzA3NjE)
[<img src="figs/v4.png" width="400px"/>](https://imgsli.com/NzA3NjM)
[<img src="figs/v5.png" width="400px"/>](https://imgsli.com/NzA3NjQ)
[<img src="figs/v6.png" width="400px"/>](https://imgsli.com/NzA3NjU)

________

### 1. Motivations
JPEG is one of the most widely used image compression algorithms and formats due to its simplicity and fast encoding/decoding speed. However, it is a lossy compression algorithm and can introduce annoying artifacts. Existing methods for JPEG artifacts removal generally have four limitations in real applications:

a. Most existing learning-based methods [e.g. ARCNN, MWCNN, SwinIR] train a specific model for each quality factor and lack the flexibility of a single model that handles different JPEG quality factors.

b. DCT-based methods [e.g. DMCNN, QGAC] need the DCT coefficients or quantization table as input, which are only stored in the JPEG format. Besides, when an image is compressed multiple times, only the most recent compression information is stored.

c. Existing blind methods [e.g. DnCNN, DCSC, QGAC] can only provide a deterministic reconstruction for each input, ignoring the need for user control.

d. Existing methods are all trained on synthetic images under the assumption that the low-quality images are compressed only once. However, most images from the Internet are compressed multiple times. Despite some progress on real recompressed images, e.g. from Twitter [ARCNN, DCSC], a detailed and complete study of double JPEG artifacts removal is still missing.

### 2. Network Architecture
We propose a flexible blind convolutional neural network (FBCNN) that predicts the quality factor of a JPEG image and embeds it into the decoder to guide image restoration. The quality factor can also be adjusted manually for flexible JPEG restoration according to the user's preference.
![architecture](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/architecture.png)

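As a rough illustration of how such quality-factor conditioning can work, here is a minimal PyTorch sketch that maps a scalar quality factor to per-channel scale/shift parameters that modulate decoder features (FiLM-style). The module and parameter names are made up for illustration; this is not the official FBCNN implementation, whose exact design is shown in the figure above and in the paper.

```python
import torch
import torch.nn as nn

class QFModulation(nn.Module):
    """Illustrative sketch: map a normalized quality factor in [0, 1] to
    per-channel scale/shift that modulates a decoder feature map.
    Not the official FBCNN module, just a conceptual example."""
    def __init__(self, channels, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, 2 * channels),
        )

    def forward(self, feat, qf):
        # feat: (B, C, H, W) decoder features; qf: (B, 1) normalized quality factor
        gamma, beta = self.mlp(qf).chunk(2, dim=1)   # (B, C) each
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)    # (B, C, 1, 1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return feat * (1 + gamma) + beta             # scale-and-shift modulation
```

With a conditioning path of this kind, the predicted quality factor can be replaced by a user-specified value at test time, which is what enables the flexible control shown later.
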
### 3. Considerations on Double JPEG Restoration
#### a. Limitation of Existing Blind Methods
We find that existing blind methods consistently fail when the 8x8 blocks of the two JPEG compressions are not aligned and QF1 <= QF2, _**even with just a one-pixel shift.**_ Other cases, such as non-aligned double JPEG compression with QF1 > QF2 or aligned double JPEG compression, are actually equivalent to single JPEG compression.

Here is an example of the restoration results of DnCNN and QGAC on a JPEG image with different degradation settings. '*' means there is a one-pixel shift between the blocks of the two JPEG compressions.
![lena_doublejpeg](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/lena_doublejpeg.png)

#### b. Our Solutions
We find that for non-aligned double JPEG images with QF1 < QF2, FBCNN always predicts the quality factor as QF2. However, it is the smaller QF1 that dominates the compression artifacts. By manually changing the predicted quality factor to QF1, we largely improve the results.

Besides, to obtain a fully blind model, we propose two solutions to this problem:

(1) FBCNN-D: Train a model with the single JPEG degradation model and apply automatic dominant-QF correction. By exploiting a property of JPEG images, we find that the quality factor of a single JPEG image can be predicted by applying another JPEG compression: when QF1 = QF2, the MSE between the two JPEG images is minimal. In our paper, we also extend this method to non-aligned double JPEG cases to obtain a fully blind model.

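A minimal OpenCV sketch of this recompression idea is shown below; the candidate grid and helper name are assumptions for illustration, not the exact FBCNN-D procedure.

```python
import cv2
import numpy as np

def estimate_jpeg_qf(img_bgr, candidates=range(5, 100, 5)):
    """Estimate the quality factor of an already-compressed JPEG image by
    recompressing it at candidate QFs and picking the one with minimal MSE,
    since the MSE is smallest when the candidate matches the original QF."""
    best_qf, best_mse = None, float('inf')
    for qf in candidates:
        _, enc = cv2.imencode('.jpg', img_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
        rec = cv2.imdecode(enc, cv2.IMREAD_COLOR)
        mse = np.mean((img_bgr.astype(np.float64) - rec.astype(np.float64)) ** 2)
        if mse < best_mse:
            best_qf, best_mse = qf, mse
    return best_qf
```
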
(2) FBCNN-A: Augment the training data with a double JPEG degradation model, which is given by:

<p align="center">
y = JPEG(shift(JPEG(x, QF1)), QF2)
</p>

By reducing the mismatch between the training data and real-world JPEG images, FBCNN-A further improves the results on complex double JPEG restoration. _**This double JPEG degradation model can easily be integrated into other image restoration tasks, such as single image super-resolution, for better general real image restoration.**_

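A minimal OpenCV sketch of this degradation model is given below; implementing `shift` as a simple crop and the default (4, 4) offset are assumptions for illustration.

```python
import cv2

def jpeg_compress(img_bgr, qf):
    """In-memory JPEG round trip at quality factor qf."""
    _, enc = cv2.imencode('.jpg', img_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
    return cv2.imdecode(enc, cv2.IMREAD_COLOR)

def double_jpeg_degradation(img_bgr, qf1, qf2, shift=(4, 4)):
    """y = JPEG(shift(JPEG(x, QF1)), QF2): non-aligned double JPEG degradation.
    The crop misaligns the 8x8 block grids of the two compressions
    (any offset that is not a multiple of 8, e.g. one pixel, works)."""
    once = jpeg_compress(img_bgr, qf1)
    dy, dx = shift
    shifted = once[dy:, dx:, :]
    return jpeg_compress(shifted, qf2)
```
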
To the best of our knowledge, we are the first to tackle the restoration of non-aligned double JPEG compression. As JPEG is the most widely used image compression algorithm and format, and most real-world JPEG images are compressed many times, we believe this is a significant step towards real image restoration.

### 4. Experiments

#### a. Single JPEG Restoration
![single_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_table.png)
*: Train a specific model for each quality factor.
![single_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_compare.png)

#### b. Non-Aligned Double JPEG Restoration
There is a pixel shift of (4,4) between the blocks of the two JPEG compressions.
![double_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_table.png)
![double_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_compare.png)

#### c. Real JPEG Restoration
![real](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/real.png)

#### d. Flexibility of FBCNN
By setting different quality factors, we can control the trade-off between artifacts removal and detail preservation.
![flexible](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/flexible.png)

### 5. Training
We will release the training code at [KAIR](https://github.com/cszn/KAIR/).

### 6. Testing
#### a. Grayscale Images (Calculate Metrics)
Put a folder with uncompressed grayscale or Y-channel images (Classic5, LIVE1, BSDS500, etc.) under `testsets`. The script generates the compressed JPEG images and calculates PSNR, SSIM, and PSNRB.

###### Single JPEG Restoration
```bash
python main_test_fbcnn_gray.py
```
###### Double JPEG Restoration
```bash
python main_test_fbcnn_gray_doublejpeg.py
```

#### b. Color Images (Calculate Metrics)
Put a folder with uncompressed color images (LIVE1, BSDS500, ICB, etc.) under `testsets`. The script generates the compressed JPEG images and calculates PSNR, SSIM, and PSNRB.

```bash
python main_test_fbcnn_color.py
```
#### c. Real-World Images (Real Application)
Put a folder with real-world compressed images under `testsets`. The script restores the images without calculating metrics. Note that by setting different quality factors as input, you can control the trade-off between artifacts removal and detail preservation.
```bash
python main_test_fbcnn_color_real.py
```

### Citation
    @inproceedings{jiang2021towards,
      title={Towards Flexible Blind {JPEG} Artifacts Removal},
      author={Jiang, Jiaxi and Zhang, Kai and Timofte, Radu},
      booktitle={IEEE International Conference on Computer Vision},
      year={2021}
    }

# Color-image test script with metrics (invoked above as `python main_test_fbcnn_color.py`):
# compress each test image at several JPEG quality factors, restore it with FBCNN,
# and report PSNR / SSIM / PSNRB.
import os.path
import logging
import numpy as np
from collections import OrderedDict
import torch
import cv2
from utils import utils_logger
from utils import utils_image as util


def main():

    quality_factor_list = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    testset_name = 'LIVE1_color'   # 'LIVE1_color' 'BSDS500_color' 'ICB'
    n_channels = 3                 # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_color.pth'
    nc = [64, 128, 256, 512]
    nb = 4
    show_img = False               # default: False
    testsets = 'testsets'
    results = 'test_results'

    for quality_factor in quality_factor_list:

        result_name = testset_name + '_' + model_name[:-4]
        H_path = os.path.join(testsets, testset_name)
        E_path = os.path.join(results, result_name, str(quality_factor))  # E_path, for Estimated images
        util.mkdir(E_path)

        model_pool = 'model_zoo'  # fixed
        model_path = os.path.join(model_pool, model_name)
        logger_name = result_name + '_qf_' + str(quality_factor)
        utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor))

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        border = 0

        # ----------------------------------------
        # load model
        # ----------------------------------------

        from models.network_fbcnn import FBCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        logger.info('Model path: {:s}'.format(model_path))

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnrb'] = []

        H_paths = util.get_image_paths(H_path)
        for idx, img in enumerate(H_paths):

            # ------------------------------------
            # (1) img_L: JPEG-compress the ground-truth image in memory
            # ------------------------------------
            img_name, ext = os.path.splitext(os.path.basename(img))
            logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
            img_L = util.imread_uint(img, n_channels=n_channels)

            if n_channels == 3:
                img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            _, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
            img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3)
            if n_channels == 3:
                img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB)
            img_L = util.uint2tensor4(img_L)
            img_L = img_L.to(device)

            # ------------------------------------
            # (2) img_E: restore with FBCNN, which also returns the predicted QF
            # ------------------------------------

            # img_E, QF = model(img_L, torch.tensor([[0.6]]))  # optionally set the quality factor manually
            img_E, QF = model(img_L)
            QF = 1 - QF  # the network predicts 1 - QF/100; convert back
            img_E = util.tensor2single(img_E)
            img_E = util.single2uint(img_E)
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze()

            # --------------------------------
            # PSNR, SSIM and PSNRB
            # --------------------------------

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            psnrb = util.calculate_psnrb(img_H, img_E, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['psnrb'].append(psnrb)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.3f}; PSNRB: {:.2f} dB.'.format(img_name + ext, psnr, ssim, psnrb))
            logger.info('predicted quality factor: {:d}'.format(round(float(QF * 100))))

            util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None
            util.imsave(img_E, os.path.join(E_path, img_name + '.png'))

        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
        logger.info(
            'Average PSNR/SSIM/PSNRB - {} -: {:.2f}$\\vert${:.4f}$\\vert${:.2f}.'.format(
                result_name + '_' + str(quality_factor), ave_psnr, ave_ssim, ave_psnrb))


if __name__ == '__main__':
    main()

# Real-world test script (invoked above as `python main_test_fbcnn_color_real.py`):
# restore real compressed images without computing metrics, optionally re-running
# the model with user-specified quality factors for flexible control.
import os.path
import logging
import torch
import cv2
from utils import utils_logger
from utils import utils_image as util


def main():

    testset_name = 'Real'          # folder name of real images
    n_channels = 3                 # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_color.pth'
    nc = [64, 128, 256, 512]
    nb = 4
    testsets = 'testsets'
    results = 'test_results'

    do_flexible_control = True
    QF_control = [10, 30, 50, 70, 90]  # adjust the QF input to provide different results

    result_name = testset_name + '_' + model_name[:-4]
    L_path = os.path.join(testsets, testset_name)
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name)
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_fbcnn import FBCNN as net
    model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='BR')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    L_paths = util.get_image_paths(L_path)
    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L: read the real compressed image
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)

        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E: blind restoration with the predicted quality factor
        # ------------------------------------

        img_E, QF = model(img_L)
        QF = 1 - QF  # the network predicts 1 - QF/100; convert back
        img_E = util.tensor2single(img_E)
        img_E = util.single2uint(img_E)
        logger.info('predicted quality factor: {:d}'.format(round(float(QF * 100))))
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))

        if do_flexible_control:
            # re-run the model with user-specified quality factors to trade off
            # artifacts removal against detail preservation
            for QF_set in QF_control:
                logger.info('Flexible control by QF = {:d}'.format(QF_set))
                qf_input = torch.tensor([[1 - QF_set / 100]]).to(device)
                img_E, QF = model(img_L, qf_input)
                QF = 1 - QF
                img_E = util.tensor2single(img_E)
                img_E = util.single2uint(img_E)
                util.imsave(img_E, os.path.join(E_path, img_name + '_qf_' + str(QF_set) + '.png'))


if __name__ == '__main__':
    main()