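# Flask service that wraps a pretrained ADNet denoiser: the /predict endpoint
# accepts a base64-encoded image, runs it through the network, and returns the
# denoised image together with PSNR/SSIM scores comparing input and output.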
from torchvision import transforms
import torch
import torch.nn as nn
from torchmetrics import StructuralSimilarityIndexMeasure
from skimage.metrics import peak_signal_noise_ratio
import numpy as np
import io
from PIL import Image
from flask import Flask, jsonify, request
from flask_cors import CORS
import base64

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})
app.config['CORS_HEADERS'] = 'Content-Type'
class ADNet(nn.Module):
    # Attention-guided denoising CNN: a stack of conv blocks predicts feature
    # maps, a 1x1 convolution over the concatenated input and features produces
    # an attention map, and the estimated noise is subtracted from the input.
    def __init__(self, channels, num_of_layers=15):
        super(ADNet, self).__init__()
        kernel_size = 3
        padding = 1
        features = 64
        groups = 1
        layers = []
        kernel_size1 = 1
        # 3x3 conv + BN + ReLU stack; layers 2, 5, 9 and 12 use dilation=2 to
        # enlarge the receptive field.
        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels=channels,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_3 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_4 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_5 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_6 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_7 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_8 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_9 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_10 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_11 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_12 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=2,groups=groups,bias=False,dilation=2),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_13 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_14 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=padding,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        self.conv1_15 = nn.Sequential(nn.Conv2d(in_channels=features,out_channels=features,kernel_size=kernel_size,padding=1,groups=groups,bias=False),nn.BatchNorm2d(features),nn.ReLU(inplace=True))
        # Reconstruction: map the 64 feature maps back to a 3-channel image.
        self.conv1_16 = nn.Conv2d(in_channels=features,out_channels=3,kernel_size=kernel_size,padding=1,groups=groups,bias=False)
        # 1x1 attention convolution applied to the concatenated (input, features) pair.
        self.conv3 = nn.Conv2d(in_channels=6,out_channels=3,kernel_size=1,stride=1,padding=0,groups=1,bias=True)
        self.ReLU = nn.ReLU(inplace=True)
        self.Tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        # Weight initialisation; BatchNorm scale factors are clipped away from
        # zero so they cannot collapse.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)
                clip_b = 0.025
                w = m.weight.data.shape[0]
                for j in range(w):
                    if m.weight.data[j] >= 0 and m.weight.data[j] < clip_b:
                        m.weight.data[j] = clip_b
                    elif m.weight.data[j] > -clip_b and m.weight.data[j] < 0:
                        m.weight.data[j] = -clip_b
                m.running_var.fill_(0.01)
    def _make_layers(self, block, features, kernel_size, num_of_layers, padding=1, groups=1, bias=False):
        layers = []
        for _ in range(num_of_layers):
            layers.append(block(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, groups=groups, bias=bias))
        return nn.Sequential(*layers)
    def forward(self, x):
        input = x
        x1 = self.conv1_1(x)
        x1 = self.conv1_2(x1)
        x1 = self.conv1_3(x1)
        x1 = self.conv1_4(x1)
        x1 = self.conv1_5(x1)
        x1 = self.conv1_6(x1)
        x1 = self.conv1_7(x1)
        x1t = self.conv1_8(x1)
        x1 = self.conv1_9(x1t)
        x1 = self.conv1_10(x1)
        x1 = self.conv1_11(x1)
        x1 = self.conv1_12(x1)
        x1 = self.conv1_13(x1)
        x1 = self.conv1_14(x1)
        x1 = self.conv1_15(x1)
        x1 = self.conv1_16(x1)
        # Attention stage: concatenate the input with the predicted features,
        # turn them into weights via Tanh and the 1x1 convolution, weight the
        # features to estimate the noise, and subtract it from the input.
        out = torch.cat([x, x1], 1)
        out = self.Tanh(out)
        out = self.conv3(out)
        out = out * x1
        out2 = x - out
        return out2
net = ADNet(channels=3, num_of_layers=17)
device_ids = [0]
model = nn.DataParallel(net, device_ids=device_ids)
model.load_state_dict(torch.load("./Model/original17/model_100.pth", map_location=torch.device('cpu')))
model.eval()
def metrics(image, out):
    # PSNR/SSIM computed between the noisy input and the denoised output.
    ssim_module = StructuralSimilarityIndexMeasure(data_range=1)
    SSIM = ssim_module(out, image)
    Img = image.data.numpy().astype(np.float32)
    Iclean = out.data.numpy().astype(np.float32)
    PSNR = peak_signal_noise_ratio(Iclean, Img, data_range=1.)
    return PSNR, SSIM.item()
def get_prediction(image):
    # Decode to RGB so the input always has the 3 channels the network expects.
    image = Image.open(image).convert("RGB")
    trans = transforms.ToPILImage()
    loader = transforms.Compose([
        transforms.ToTensor()
    ])
    image = loader(image).float()
    image = image.unsqueeze(0)
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        out = model(image)
    psnr, ssim = metrics(image, out)
    out = torch.clamp(out, 0., 1.)
    out = out.squeeze(0)
    return trans(out), psnr, ssim
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        im = io.BytesIO(base64.b64decode(request.json['x']))
        clean_image, psnr, ssim = get_prediction(image=im)
        img_io = io.BytesIO()
        clean_image.save(img_io, 'JPEG', quality=100)
        x = base64.b64encode(img_io.getvalue())
        return jsonify({'image': x.decode('utf-8'), 'psnr': psnr, 'ssim': ssim})
if __name__ == '__main__':
    app.run(debug=True)
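
# Example client call (illustrative sketch only; assumes the server is running
# locally on http://localhost:5000 and that "noisy.jpg" exists in the working
# directory):
#
#   import base64, requests
#   with open("noisy.jpg", "rb") as f:
#       payload = {"x": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:5000/predict", json=payload)
#   result = resp.json()
#   with open("denoised.jpg", "wb") as f:
#       f.write(base64.b64decode(result["image"]))
#   print("PSNR:", result["psnr"], "SSIM:", result["ssim"])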