models_psloss.py
import torch
import torch.nn as nn


class DnCNN(nn.Module):
    """DnCNN denoiser: a stack of Conv-BN-ReLU layers that predicts the noise residual."""

    def __init__(self, in_channels, out_channels, num_of_layers=17):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        features = 64
        layers = []
        # First layer: Conv + ReLU (no batch norm).
        layers.append(nn.Conv2d(in_channels=in_channels, out_channels=features,
                                kernel_size=kernel_size, padding=padding, bias=False))
        layers.append(nn.ReLU(inplace=True))
        # Middle layers: Conv + BatchNorm + ReLU, repeated (num_of_layers - 2) times.
        for _ in range(num_of_layers - 2):
            layers.append(nn.Conv2d(in_channels=features, out_channels=features,
                                    kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
        # Last layer: Conv only, mapping back to the output channel count.
        layers.append(nn.Conv2d(in_channels=features, out_channels=out_channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming-normal init for convolutions; identity init (ones/zeros) for batch norm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # nn.init.dirac_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # The network outputs the predicted noise residual, not the clean image.
        residual = self.dncnn(x)
        return residual
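

# A minimal usage sketch, not part of the original file: the channel counts,
# layer count, and tensor shape below are assumptions for illustration
# (e.g. single-channel grayscale patches).
if __name__ == "__main__":
    model = DnCNN(in_channels=1, out_channels=1, num_of_layers=17)
    noisy = torch.randn(4, 1, 64, 64)      # batch of 4 noisy 64x64 grayscale patches
    with torch.no_grad():
        residual = model(noisy)             # predicted noise map, same shape as the input
    denoised = noisy - residual             # standard DnCNN residual subtraction
    print(residual.shape, denoised.shape)   # torch.Size([4, 1, 64, 64]) for both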