# Network_res3d.py
import numpy as np
import torch
import torch.nn as nn
import torch_resizer  # only used by the commented-out cubic-resize path in forward()
from Network import Network
from augmentations import *
from fourier import double_frames_fourier


class Network_residual(Network):  # Network that adds the temporally upsampled input back as a residual
    def __init__(self, config, device, upsample_scale):
        super().__init__(config, device, upsample_scale)

    def build_network(self):
        """
        Build the network from the flags/parameters in self.config.
        :return: net - a torch nn.Module that can be trained
        """
        class NeuralNetwork(nn.Module):
            def __init__(self, channels_in, channels_out, config, upsample_scale):
                super().__init__()
                self.config = config
                self.upsample_scale = upsample_scale

                def make_conv(in_ch, out_ch, kernel_size, padding):
                    # Conv3d with the scaled-normal initialization used for every layer below
                    conv = nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size,
                                     padding=padding, padding_mode='zeros')
                    torch.nn.init.normal_(conv.weight, mean=0, std=np.sqrt(0.1 / np.prod(conv.weight.shape[1:])))
                    torch.nn.init.normal_(conv.bias, mean=0, std=np.sqrt(0.1))
                    return conv

                # L1-L2 convolve across time and space (3x3x3 kernels); L3-L7 are spatial-only
                # (1x3x3); L8 is spatio-temporal again and projects to channels_out (no batch norm).
                self.L1 = make_conv(channels_in, 128, kernel_size=3, padding=1)
                self.L1_b = nn.BatchNorm3d(128)
                self.L2 = make_conv(128, 128, kernel_size=3, padding=1)
                self.L2_b = nn.BatchNorm3d(128)
                self.L3 = make_conv(128, 128, kernel_size=(1, 3, 3), padding=(0, 1, 1))
                self.L3_b = nn.BatchNorm3d(128)
                self.L4 = make_conv(128, 128, kernel_size=(1, 3, 3), padding=(0, 1, 1))
                self.L4_b = nn.BatchNorm3d(128)
                self.L5 = make_conv(128, 128, kernel_size=(1, 3, 3), padding=(0, 1, 1))
                self.L5_b = nn.BatchNorm3d(128)
                self.L6 = make_conv(128, 128, kernel_size=(1, 3, 3), padding=(0, 1, 1))
                self.L6_b = nn.BatchNorm3d(128)
                self.L7 = make_conv(128, 128, kernel_size=(1, 3, 3), padding=(0, 1, 1))
                self.L7_b = nn.BatchNorm3d(128)
                self.L8 = make_conv(128, channels_out, kernel_size=3, padding=1)
                self.activation = nn.ReLU()
            def forward(self, x):
                # x -> [Batch, Channel, Time, Height, Width]
                # Build the temporally upsampled base that the network output is added onto.
                residual_base = self.config["res3d_up_method"]  # one of: 'resize', 'duplicate', 'zero_gap'
                if residual_base == 'resize':
                    # Fourier-domain temporal upsampling; double_frames_fourier doubles the frame
                    # count, so this path effectively assumes upsample_scale == 2.
                    # A cubic resizer is an alternative, kept here for reference:
                    # self.resizer = torch_resizer.Resizer(x.shape, scale_factor=(1, 1, self.upsample_scale, 1, 1),
                    #                                      output_shape=[x.shape[0], x.shape[1], x.shape[2] * self.upsample_scale, x.shape[3], x.shape[4]],
                    #                                      kernel='cubic', antialiasing=True, device='cpu')
                    # x_upsampled = self.resizer(x)
                    # x = self.resizer(x)
                    x_upsampled = double_frames_fourier(x)
                    x = x_upsampled
                elif residual_base == 'duplicate':
                    # Nearest-neighbour along time: every upsampled frame is a copy of its source frame.
                    x_upsampled = torch.nn.functional.interpolate(
                        x, scale_factor=(self.upsample_scale, 1, 1), mode='trilinear', align_corners=False)
                    for frame_up_idx in range(x_upsampled.shape[2]):
                        x_upsampled[:, :, frame_up_idx, :, :] = x[:, :, frame_up_idx // self.upsample_scale, :, :]
                    # The network input is the same duplicated base (built once instead of twice).
                    x = x_upsampled.clone()
                elif residual_base == 'zero_gap':
                    # Keep the original frames on the upsampled time grid and fill the gaps with zero frames.
                    zero_frame = torch.zeros_like(x[:, :, 0, :, :])
                    x_upsampled = torch.nn.functional.interpolate(
                        x, scale_factor=(self.upsample_scale, 1, 1), mode='trilinear', align_corners=False)
                    for frame_up_idx in range(x_upsampled.shape[2]):
                        if frame_up_idx % self.upsample_scale == 0:  # insert original frame
                            x_upsampled[:, :, frame_up_idx, :, :] = x[:, :, frame_up_idx // self.upsample_scale, :, :]
                        else:
                            x_upsampled[:, :, frame_up_idx, :, :] = zero_frame
                    # The network input is the same zero-gapped base (built once instead of twice).
                    x = x_upsampled.clone()
                else:
                    raise ValueError(f'Network_residual.forward - unknown res3d_up_method: {residual_base}')
                # Replicate-pad one voxel on every side so the first convolution sees replicated
                # rather than zero borders; the pad is cropped off again before the residual add.
                x = torch.nn.functional.pad(x, [1, 1, 1, 1, 1, 1], mode='replicate')
                x = self.activation(self.L1_b(self.L1(x)))
                x = self.activation(self.L2_b(self.L2(x)))
                x = self.activation(self.L3_b(self.L3(x)))
                x = self.activation(self.L4_b(self.L4(x)))
                x = self.activation(self.L5_b(self.L5(x)))
                x = self.activation(self.L6_b(self.L6(x)))
                x = self.activation(self.L7_b(self.L7(x)))
                x = self.L8(x)
                # Crop the replicate padding and add the upsampled base (the residual connection).
                return x[:, :, 1:-1, 1:-1, 1:-1] + x_upsampled

        net = NeuralNetwork(self.channels_in, self.channels_out, self.config, self.upsample_scale).to(self.device)
        return net
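

# ---------------------------------------------------------------------------
# Standalone sketch (not part of the original module): it replays the
# 'duplicate' residual base from forward() above on a tiny dummy clip, to show
# that it amounts to nearest-neighbour upsampling along the time axis. The
# shapes and the scale factor here are illustrative assumptions only.
if __name__ == '__main__':
    scale = 2
    # A clip of 4 frames, 2x2 pixels each, where every pixel value equals its frame index.
    clip = torch.arange(4.0).reshape(1, 1, 4, 1, 1).expand(1, 1, 4, 2, 2).contiguous()
    up = torch.nn.functional.interpolate(
        clip, scale_factor=(scale, 1, 1), mode='trilinear', align_corners=False)
    for t in range(up.shape[2]):
        up[:, :, t] = clip[:, :, t // scale]  # every new frame copies its source frame
    print(up[0, 0, :, 0, 0])  # tensor([0., 0., 1., 1., 2., 2., 3., 3.])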