release the code of FBCNN
jiaxi-jiang committed Sep 13, 2021
0 parents commit 2cd9408
Showing 96 changed files with 2,347 additions and 0 deletions.
112 changes: 112 additions & 0 deletions README.md
@@ -0,0 +1,112 @@
## Towards Flexible Blind JPEG Artifacts Removal (ICCV 2021) (PyTorch)

Jiaxi Jiang, Kai Zhang, Radu Timofte

Computer Vision Lab, ETH Zurich, Switzerland

________
_**Some visual examples (click the images for more details)**_:

[<img src="figs/v3.png" width="400px"/>](https://imgsli.com/NzA3NjI)
[<img src="figs/v1.png" width="400px"/>](https://imgsli.com/NzA3NTk)
[<img src="figs/v2.png" width="400px"/>](https://imgsli.com/NzA3NjE)
[<img src="figs/v4.png" width="400px"/>](https://imgsli.com/NzA3NjM)
[<img src="figs/v5.png" width="400px"/>](https://imgsli.com/NzA3NjQ)
[<img src="figs/v6.png" width="400px"/>](https://imgsli.com/NzA3NjU)

________

### 1. Motivations
JPEG is one of the most widely used image compression algorithms and formats due to its simplicity and fast encoding/decoding speeds. However, it is a lossy compression algorithm and can introduce annoying artifacts. Existing methods for JPEG artifacts removal generally have four limitations in real applications:

a. Most existing learning-based methods [e.g. ARCNN, MWCNN, SwinIR] train a separate model for each quality factor, so no single model can handle different JPEG quality factors.

b. DCT-based methods [e.g. DMCNN, QGAC] need the DCT coefficients or the quantization table as input, which are only available in the JPEG format. Besides, when an image is compressed multiple times, only the most recent compression information is stored.

c. Existing blind methods [e.g. DnCNN, DCSC, QGAC] can only provide a single deterministic reconstruction for each input, ignoring user preferences.

d. Existing methods are all trained on synthetic images, under the assumption that low-quality images are compressed only once. However, most images from the Internet are compressed multiple times. Despite some progress on real recompressed images, e.g. from Twitter [ARCNN, DCSC], a detailed and complete study on double JPEG artifacts removal is still missing.


### 2. Network Architecture
We propose a flexible blind convolutional neural network (FBCNN) that predicts the quality factor of a JPEG image and embeds it into the decoder to guide image restoration. The quality factor can be manually adjusted for flexible JPEG restoration according to the user's preference.
![architecture](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/architecture.png)
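
As a rough illustration of how the quality factor is passed around: the test scripts in this commit feed the network `1 - QF/100` and convert its prediction back with `QF = 1 - output`. The sketch below shows only that mapping. The helper names `qf_to_input` and `output_to_qf` are hypothetical and are not part of the repository.

```python
def qf_to_input(qf: int) -> float:
    """Map a JPEG quality factor (0..100) to the normalized value fed to the network."""
    if not 0 <= qf <= 100:
        raise ValueError("quality factor must be in 0..100")
    return 1.0 - qf / 100.0


def output_to_qf(pred: float) -> int:
    """Map the network's normalized prediction back to a quality factor."""
    return round((1.0 - pred) * 100)


# Round trip for a few quality factors.
for qf in (10, 50, 90):
    assert output_to_qf(qf_to_input(qf)) == qf
```

Under this convention, a larger network input corresponds to a lower quality factor.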

### 3. Considerations on Double JPEG Restoration
#### a. Limitation of Existing Blind Methods
We find that existing blind methods consistently fail when the 8x8 blocks of the two JPEG compressions are not aligned and QF1 <= QF2, _**even with just a one-pixel shift.**_ Other cases, such as non-aligned double JPEG with QF1>QF2 or aligned double JPEG compression, are actually equivalent to single JPEG compression.

Here is an example of the restoration results of DnCNN and QGAC on a JPEG image with different degradation settings. '*' means there is a one-pixel shift between the block grids of the two JPEG compressions.
![lena_doublejpeg](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/lena_doublejpeg.png)
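
The case distinction above can be written down as a small predicate. This is only a sketch of the rule as stated in this section (8x8 blocks, difficult when the grids are not aligned and QF1 <= QF2). The function names are hypothetical and do not come from the repository.

```python
def is_aligned(shift_x: int, shift_y: int, block: int = 8) -> bool:
    """Two JPEG block grids coincide when both shifts are multiples of the block size."""
    return shift_x % block == 0 and shift_y % block == 0


def is_hard_case(qf1: int, qf2: int, shift_x: int, shift_y: int) -> bool:
    """True for non-aligned double JPEG with QF1 <= QF2, the case described above."""
    return not is_aligned(shift_x, shift_y) and qf1 <= qf2


print(is_hard_case(30, 50, 1, 0))   # one-pixel shift, QF1 <= QF2
print(is_hard_case(50, 30, 1, 0))   # QF1 > QF2: behaves like single JPEG
print(is_hard_case(30, 50, 8, 16))  # aligned grids: behaves like single JPEG
```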


#### b. Our Solutions
We find that for non-aligned double JPEG images with QF1 < QF2, FBCNN always predicts the quality factor as QF2. However, it is the smaller QF1 that dominates the compression artifacts. By manually changing the predicted quality factor to QF1, we largely improve the result.

Besides, to get a fully blind model, we propose two blind solutions to solve this problem:

(1) FBCNN-D: Train a model with a single JPEG degradation model + automatic dominant QF correction. Exploiting a property of JPEG images, we find that the quality factor QF1 of a single JPEG image can be predicted by applying another JPEG compression with quality factor QF2: when QF1 = QF2, the MSE between the two JPEG images is minimal. In our paper, we also extend this method to non-aligned double JPEG cases to get a fully blind model.

(2) FBCNN-A: Augment training data with double JPEG degradation model, which is given by:

<p align="center">
y = JPEG(shift(JPEG(x, QF1)),QF2)
</p>

By reducing the mismatch between training data and real-world JPEG images, FBCNN-A further improves the results on complex double JPEG restoration. _**This proposed double JPEG degradation model can be easily integrated into other image restoration tasks, such as Single Image Super-Resolution, for better general real image restoration.**_
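
A minimal sketch of the degradation model `y = JPEG(shift(JPEG(x, QF1)),QF2)` follows. It rests on several assumptions: the shift is modeled as cropping rows and columns from the top-left corner, the JPEG round trip uses OpenCV in the same way as the test scripts in this commit, and the names `cv2_jpeg` and `double_jpeg` are hypothetical. The training code may implement the shift differently.

```python
import numpy as np


def cv2_jpeg(img: np.ndarray, qf: int) -> np.ndarray:
    """Round-trip an 8-bit image (grayscale or BGR) through JPEG at quality `qf`."""
    import cv2  # imported lazily so the rest of the sketch runs without OpenCV

    ok, enc = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
    if not ok:
        raise RuntimeError('JPEG encoding failed')
    return cv2.imdecode(enc, cv2.IMREAD_UNCHANGED)


def double_jpeg(x, qf1, qf2, shift=(4, 4), jpeg=cv2_jpeg):
    """Apply JPEG(QF1), shift the block grid, then apply JPEG(QF2).

    Assumption: the shift is a crop of `dy` rows and `dx` columns, so the
    second compression sees a block grid offset from the first one.
    """
    dy, dx = shift
    first = jpeg(x, qf1)
    shifted = np.ascontiguousarray(first[dy:, dx:])
    return jpeg(shifted, qf2)
```

With OpenCV installed, `double_jpeg(img, 30, 50)` on a `uint8` image returns a degraded image that is 4 rows and 4 columns smaller than `img`, so the ground-truth image needs the same crop before any comparison. The `jpeg` argument allows another codec to be substituted.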

To the best of our knowledge, we are the first to tackle the problem of restoration of non-aligned double JPEG compression. As JPEG is the most widely used image compression algorithm and format, and most real-world JPEG images are compressed many times, we believe it would be a significant step towards real image restoration.
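
The recompression idea behind FBCNN-D can be sketched as a search over candidate quality factors: recompress the image with each candidate and keep the one that changes it the least. The code below is only an illustration of that search. `estimate_qf` and `toy_codec` are hypothetical names, and the toy codec is a plain scalar quantizer standing in for JPEG, so it does not reproduce real JPEG behavior (with a real codec the minimum is small rather than exactly zero). The extension to non-aligned double JPEG described in the paper is not shown.

```python
def mse(a, b):
    """Mean squared error between two equal-length pixel sequences."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


def estimate_qf(pixels, recompress, candidates):
    """Return the candidate QF whose recompression changes `pixels` the least."""
    return min(candidates, key=lambda qf: mse(pixels, recompress(pixels, qf)))


# Toy stand-in for a JPEG codec (NOT real JPEG): uniform scalar quantization
# whose step grows as the quality factor drops. The step values are arbitrary.
TOY_STEPS = {10: 23, 30: 17, 50: 11, 70: 7, 90: 3}


def toy_codec(pixels, qf):
    step = TOY_STEPS[qf]
    return [round(p / step) * step for p in pixels]


original = [12, 47, 101, 158, 203, 250]
compressed = toy_codec(original, 50)
print(estimate_qf(compressed, toy_codec, TOY_STEPS))  # → 50
```

Passing the codec as a callable keeps the search independent of the compression library. A real implementation would plug in an actual JPEG round trip instead of `toy_codec`.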

### 4. Experiments

#### a. Single JPEG Restoration
![single_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_table.png)
*: A specific model is trained for each quality factor.
![single_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/single_compare.png)

#### b. Non-Aligned Double JPEG Restoration
There is a pixel shift of (4,4) between the blocks of the two JPEG compressions.
![double_table](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_table.png)
![double_compare](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/double_compare.png)

#### c. Real JPEG Restoration
![real](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/real.png)

#### d. Flexibility of FBCNN
By setting different quality factors, we can control the trade-off between artifacts removal and details preservation.
![flexible](https://github.com/jiaxi-jiang/FBCNN/blob/main/figs/flexible.png)

### 5. Training
We will release the training code at [KAIR](https://github.com/cszn/KAIR/).

### 6. Testing
#### a. Grayscale Images (Calculate Metrics)
Put the folder with uncompressed grayscale or Y channel images (Classic5, LIVE1, BSDS500, etc.) under `testsets`. This code generates compressed JPEG images and calculates PSNR, SSIM, PSNRB.

###### Single JPEG Restoration
```bash
python main_test_fbcnn_gray.py
```
###### Double JPEG Restoration
```bash
python main_test_fbcnn_gray_doublejpeg.py
```
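
For reference, PSNR is conventionally computed from the mean squared error between the restored image and the ground truth. The sketch below follows that standard definition for 8-bit images. It is not the repository's `util.calculate_psnr`, the function name `psnr` is an assumption, and SSIM and PSNRB are not shown.

```python
import math


def psnr(a, b, peak=255.0):
    """PSNR in dB between two equal-length sequences of 8-bit pixel values."""
    if len(a) != len(b) or not a:
        raise ValueError('inputs must be non-empty and of equal length')
    mse = sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * math.log10(peak / math.sqrt(mse))


print(psnr([10, 20, 30], [11, 21, 31]))  # every pixel off by one
```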

#### b. Color Images (Calculate Metrics)
Put the folder with uncompressed images (LIVE1, BSDS500, ICB, etc.) under `testsets`. This code generates compressed JPEG images and calculates PSNR, SSIM, PSNRB.

```bash
python main_test_fbcnn_color.py
```
#### c. Real-World Images (Real Application)
Put the folder with real-world compressed images under `testsets`. This code restores the images without calculating metrics. Note that by setting different quality factors, we can control the trade-off between artifacts removal and details preservation.
```bash
python main_test_fbcnn_color_real.py
```

### Citation
@inproceedings{jiang2021towards,
title={Towards Flexible Blind {JPEG} Artifacts Removal},
author={Jiang, Jiaxi and Zhang, Kai and Timofte, Radu},
booktitle={IEEE International Conference on Computer Vision},
year={2021}
}
Binary file added figs/architecture.png
Binary file added figs/double_compare.png
Binary file added figs/double_table.png
Binary file added figs/flexible.png
Binary file added figs/lena_doublejpeg.png
Binary file added figs/real.png
Binary file added figs/single_compare.png
Binary file added figs/single_table.png
Binary file added figs/v1.png
Binary file added figs/v2.png
Binary file added figs/v3.png
Binary file added figs/v4.png
Binary file added figs/v5.png
Binary file added figs/v6.png
112 changes: 112 additions & 0 deletions main_test_fbcnn_color.py
@@ -0,0 +1,112 @@
import os.path
import logging
import numpy as np
from datetime import datetime
from collections import OrderedDict
import torch
import cv2
from utils import utils_logger
from utils import utils_image as util

def main():

    quality_factor_list = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    testset_name = 'LIVE1_color'  # 'LIVE1_color' 'BSDS500_color' 'ICB'
    n_channels = 3  # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_color.pth'
    nc = [64,128,256,512]
    nb = 4
    show_img = False  # default: False
    testsets = 'testsets'
    results = 'test_results'

    for quality_factor in quality_factor_list:

        result_name = testset_name + '_' + model_name[:-4]
        H_path = os.path.join(testsets, testset_name)
        E_path = os.path.join(results, result_name, str(quality_factor))  # E_path, for Estimated images
        util.mkdir(E_path)

        model_pool = 'model_zoo'  # fixed
        model_path = os.path.join(model_pool, model_name)
        logger_name = result_name + '_qf_' + str(quality_factor)
        utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
        logger = logging.getLogger(logger_name)
        logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor))

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        border = 0


        # ----------------------------------------
        # load model
        # ----------------------------------------

        from models.network_fbcnn import FBCNN as net
        model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        logger.info('Model path: {:s}'.format(model_path))

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnrb'] = []

        H_paths = util.get_image_paths(H_path)
        for idx, img in enumerate(H_paths):

            # ------------------------------------
            # (1) img_L
            # ------------------------------------
            img_name, ext = os.path.splitext(os.path.basename(img))
            logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
            img_L = util.imread_uint(img, n_channels=n_channels)

            if n_channels == 3:
                img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            _, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
            img_L = cv2.imdecode(encimg, 0) if n_channels == 1 else cv2.imdecode(encimg, 3)
            if n_channels == 3:
                img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB)
            img_L = util.uint2tensor4(img_L)
            img_L = img_L.to(device)

            # ------------------------------------
            # (2) img_E
            # ------------------------------------

            #img_E,QF = model(img_L, torch.tensor([[0.6]]))
            img_E,QF = model(img_L)
            QF = 1 - QF
            img_E = util.tensor2single(img_E)
            img_E = util.single2uint(img_E)
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels).squeeze()
            # --------------------------------
            # PSNR and SSIM, PSNRB
            # --------------------------------

            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            psnrb = util.calculate_psnrb(img_H, img_E, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['psnrb'].append(psnrb)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.3f}; PSNRB: {:.2f} dB.'.format(img_name+ext, psnr, ssim, psnrb))
            logger.info('predicted quality factor: {:d}'.format(round(float(QF*100))))

            util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if show_img else None
            util.imsave(img_E, os.path.join(E_path, img_name+'.png'))

        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
        logger.info(
            'Average PSNR/SSIM/PSNRB - {} -: {:.2f}$\\vert${:.4f}$\\vert${:.2f}.'.format(result_name+'_'+str(quality_factor), ave_psnr, ave_ssim, ave_psnrb))


if __name__ == '__main__':
    main()
95 changes: 95 additions & 0 deletions main_test_fbcnn_color_real.py
@@ -0,0 +1,95 @@
import os.path
import logging
import numpy as np
from datetime import datetime
from collections import OrderedDict
import torch
import cv2
from utils import utils_logger
from utils import utils_image as util

def main():

    testset_name = 'Real'  # folder name of real images
    n_channels = 3  # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_color.pth'
    nc = [64,128,256,512]
    nb = 4
    testsets = 'testsets'
    results = 'test_results'

    do_flexible_control = True
    QF_control = [10,30,50,70,90]  # adjust qf as input to provide different results

    result_name = testset_name + '_' + model_name[:-4]
    L_path = os.path.join(testsets, testset_name)
    E_path = os.path.join(results, result_name)  # E_path, for Estimated images
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name)
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    border = 0


    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_fbcnn import FBCNN as net
    model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='BR')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnrb'] = []

    L_paths = util.get_image_paths(L_path)
    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
        img_L = util.imread_uint(img, n_channels=n_channels)

        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        #img_E,QF = model(img_L, torch.tensor([[0.6]]))
        img_E,QF = model(img_L)
        QF = 1- QF
        img_E = util.tensor2single(img_E)
        img_E = util.single2uint(img_E)
        logger.info('predicted quality factor: {:d}'.format(round(float(QF*100))))
        util.imsave(img_E, os.path.join(E_path, img_name+'.png'))

        if do_flexible_control:
            for QF_set in QF_control:
                logger.info('Flexible control by QF = {:d}'.format(QF_set))
                # from IPython import embed; embed()
                qf_input = torch.tensor([[1-QF_set/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-QF_set/100]])
                img_E,QF = model(img_L, qf_input)
                QF = 1- QF
                img_E = util.tensor2single(img_E)
                img_E = util.single2uint(img_E)
                util.imsave(img_E, os.path.join(E_path, img_name + '_qf_'+ str(QF_set)+'.png'))


if __name__ == '__main__':
    main()