forked from DSL-Lab/CPEN455HW-2023W2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
141 lines (115 loc) · 5.43 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm as wn
from utils import *
class nin(nn.Module):
    """A "network in network" layer: a 1x1 convolution implemented with a Linear.

    A weight-normalized ``nn.Linear(dim_in, dim_out)`` is applied independently
    at every spatial position of a channels-first (PyTorch-ordered) 4-D input.

    Args:
        dim_in: number of input channels (size of dim 1 of the input).
        dim_out: number of output channels.
    """

    def __init__(self, dim_in, dim_out):
        super(nin, self).__init__()
        self.lin_a = wn(nn.Linear(dim_in, dim_out))
        self.dim_out = dim_out

    def forward(self, x):
        """Map a (N, dim_in, H, W) tensor to a (N, dim_out, H, W) tensor."""
        # nn.Linear acts on the last dimension, so move channels last, apply it
        # to every (n, h, w) position in one call, then move channels back to
        # dim 1. This avoids the manual flatten/view round-trip.
        out = self.lin_a(x.permute(0, 2, 3, 1))
        return out.permute(0, 3, 1, 2)
class down_shifted_conv2d(nn.Module):
    """2-D convolution whose outputs never see input rows below them.

    The input is zero-padded by ``filter_size[0] - 1`` rows on top only and by
    ``(filter_size[1] - 1) // 2`` columns on each side. With stride 1, the
    height is therefore preserved, and the width is too when ``filter_size[1]``
    is odd. Output row ``i`` depends only on input rows
    ``i - filter_size[0] + 1 .. i``.

    Args:
        num_filters_in: input channels.
        num_filters_out: output channels.
        filter_size: (kernel height, kernel width).
        stride: convolution stride.
        shift_output_down: if True, the output is passed through the
            ``down_shift`` helper (from ``utils``) with one row of top padding.
        norm: None, 'batch_norm' (BatchNorm2d applied after the conv) or
            'weight_norm' (weight-normalized conv).

    Raises:
        ValueError: if ``norm`` is not one of the supported values.
    """

    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1),
                 shift_output_down=False, norm='weight_norm'):
        super(down_shifted_conv2d, self).__init__()

        # Raise instead of assert so the check survives `python -O`.
        if norm not in (None, 'batch_norm', 'weight_norm'):
            raise ValueError(
                f"norm must be None, 'batch_norm' or 'weight_norm', got {norm!r}")
        self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride)
        self.shift_output_down = shift_output_down
        self.norm = norm
        half_width = (filter_size[1] - 1) // 2
        self.pad = nn.ZeroPad2d((half_width,          # pad left
                                 half_width,          # pad right
                                 filter_size[0] - 1,  # pad top
                                 0))                  # pad down (none: keeps causality)
        if norm == 'weight_norm':
            self.conv = wn(self.conv)
        elif norm == 'batch_norm':
            self.bn = nn.BatchNorm2d(num_filters_out)

    def down_shift(self, x):
        """Apply the ``down_shift`` helper from ``utils`` with one row of top padding.

        This is a method rather than a lambda stored on the instance, so the
        module stays picklable (e.g. for ``torch.save(model)``).
        """
        # Inside a method body the bare name resolves to the module-level
        # helper (brought in by `from utils import *`), not to this method.
        return down_shift(x, pad=nn.ZeroPad2d((0, 0, 1, 0)))

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm == 'batch_norm':
            x = self.bn(x)
        return self.down_shift(x) if self.shift_output_down else x
class down_shifted_deconv2d(nn.Module):
    """Weight-normalized transposed convolution, cropped like down_shifted_conv2d.

    The raw ``ConvTranspose2d`` output is cropped to keep only its first
    ``rows - filter_size[0] + 1`` rows and to drop ``(filter_size[1] - 1) // 2``
    columns on each side. With stride 2 this doubles the spatial size; with
    stride 1 it preserves the size (for odd ``filter_size[1]``).

    Args:
        num_filters_in: input channels.
        num_filters_out: output channels.
        filter_size: (kernel height, kernel width).
        stride: (stride height, stride width), or a single int for both.
    """

    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1)):
        super(down_shifted_deconv2d, self).__init__()
        strides = (stride, stride) if isinstance(stride, int) else tuple(stride)
        # ConvTranspose2d requires output_padding < stride (dilation is 1 here).
        # A hard-coded output_padding=1 is therefore invalid for the default
        # stride=(1, 1). Keep 1 on strided dims (the previous behaviour) and use
        # 0 where the stride is 1.
        output_padding = tuple(1 if s > 1 else 0 for s in strides)
        self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, stride,
                                            output_padding=output_padding))
        self.filter_size = filter_size
        self.stride = stride

    def forward(self, x):
        x = self.deconv(x)
        rows, cols = int(x.size(2)), int(x.size(3))
        side = (self.filter_size[1] - 1) // 2
        return x[:, :, :rows - self.filter_size[0] + 1, side:cols - side]
class down_right_shifted_conv2d(nn.Module):
    """2-D convolution whose outputs only see input pixels above and to the left.

    The input is zero-padded by ``filter_size[1] - 1`` columns on the left and
    ``filter_size[0] - 1`` rows on top, and nowhere else. With stride 1 the
    spatial size is preserved, and output ``(i, j)`` depends only on input rows
    ``i - filter_size[0] + 1 .. i`` and columns ``j - filter_size[1] + 1 .. j``.

    Args:
        num_filters_in: input channels.
        num_filters_out: output channels.
        filter_size: (kernel height, kernel width).
        stride: convolution stride.
        shift_output_right: if True, the output is passed through the
            ``right_shift`` helper (from ``utils``) with one column of left padding.
        norm: None, 'batch_norm' (BatchNorm2d applied after the conv) or
            'weight_norm' (weight-normalized conv).

    Raises:
        ValueError: if ``norm`` is not one of the supported values.
    """

    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1),
                 shift_output_right=False, norm='weight_norm'):
        super(down_right_shifted_conv2d, self).__init__()

        # Raise instead of assert so the check survives `python -O`.
        if norm not in (None, 'batch_norm', 'weight_norm'):
            raise ValueError(
                f"norm must be None, 'batch_norm' or 'weight_norm', got {norm!r}")
        # (left, right, top, bottom): pad only left and top to keep causality.
        self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0] - 1, 0))
        self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride=stride)
        self.shift_output_right = shift_output_right
        self.norm = norm
        if norm == 'weight_norm':
            self.conv = wn(self.conv)
        elif norm == 'batch_norm':
            self.bn = nn.BatchNorm2d(num_filters_out)

    def right_shift(self, x):
        """Apply the ``right_shift`` helper from ``utils`` with one column of left padding.

        This is a method rather than a lambda stored on the instance, so the
        module stays picklable (e.g. for ``torch.save(model)``).
        """
        # Inside a method body the bare name resolves to the module-level
        # helper (brought in by `from utils import *`), not to this method.
        return right_shift(x, pad=nn.ZeroPad2d((1, 0, 0, 0)))

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm == 'batch_norm':
            x = self.bn(x)
        return self.right_shift(x) if self.shift_output_right else x
class down_right_shifted_deconv2d(nn.Module):
    """Weight-normalized transposed convolution, cropped to its top-left region.

    The raw ``ConvTranspose2d`` output is cropped to its first
    ``rows - filter_size[0] + 1`` rows and ``cols - filter_size[1] + 1`` columns.
    With stride 2 this doubles the spatial size; with stride 1 it preserves it.

    Args:
        num_filters_in: input channels.
        num_filters_out: output channels.
        filter_size: (kernel height, kernel width).
        stride: (stride height, stride width), or a single int for both.
        shift_output_right: accepted for signature parity with
            ``down_right_shifted_conv2d`` but currently ignored (no shift is
            applied to the output).
    """

    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1),
                 shift_output_right=False):
        super(down_right_shifted_deconv2d, self).__init__()
        strides = (stride, stride) if isinstance(stride, int) else tuple(stride)
        # ConvTranspose2d requires output_padding < stride (dilation is 1 here).
        # A hard-coded output_padding=1 is therefore invalid for the default
        # stride=(1, 1). Keep 1 on strided dims (the previous behaviour) and use
        # 0 where the stride is 1.
        output_padding = tuple(1 if s > 1 else 0 for s in strides)
        self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size,
                                            stride, output_padding=output_padding))
        self.filter_size = filter_size
        self.stride = stride

    def forward(self, x):
        x = self.deconv(x)
        rows, cols = int(x.size(2)), int(x.size(3))
        return x[:, :, :rows - self.filter_size[0] + 1, :cols - self.filter_size[1] + 1]
class gated_resnet(nn.Module):
    """Gated residual block.

    Returns ``og_x + h * sigmoid(g)``. Here ``(h, g)`` are the two channel
    halves of::

        conv_out(dropout(nonlinearity(
            conv_input(nonlinearity(og_x)) [+ nin_skip(nonlinearity(a))])))

    The ``2 * num_filters`` input channel counts of ``conv_input`` and
    ``conv_out`` imply that ``nonlinearity`` is expected to double the channel
    dimension. The default ``concat_elu`` is defined elsewhere and presumably
    does this; verify before swapping in another activation.

    Args:
        num_filters: channel count of ``og_x`` and of the block's output.
        conv_op: layer factory called as ``conv_op(in_channels, out_channels)``,
            e.g. one of the shifted convolution classes in this file.
        nonlinearity: activation applied before each convolution.
        skip_connection: how the optional auxiliary input ``a`` is combined:
            0 = no skip connection (``a`` must not be passed);
            1 = ``a`` has the same number of channels as ``og_x``;
            2 = ``a`` has twice as many channels as ``og_x``.
            In general, assuming ``nonlinearity`` doubles channels, ``a`` has
            ``skip_connection * num_filters`` channels.
        dropout_p: drop probability of the ``nn.Dropout2d`` layer (defaults to
            0.5, the previously hard-coded value).
    """

    def __init__(self, num_filters, conv_op, nonlinearity=concat_elu, skip_connection=0,
                 dropout_p=0.5):
        super(gated_resnet, self).__init__()
        self.skip_connection = skip_connection
        self.nonlinearity = nonlinearity
        self.conv_input = conv_op(2 * num_filters, num_filters)  # 2x: nonlinearity doubles channels

        if skip_connection != 0:
            self.nin_skip = nin(2 * skip_connection * num_filters, num_filters)

        self.dropout = nn.Dropout2d(dropout_p)
        # Produces 2 * num_filters channels: a candidate update and its gate.
        self.conv_out = conv_op(2 * num_filters, 2 * num_filters)

    def forward(self, og_x, a=None):
        """Apply the block to ``og_x``, optionally mixing in the auxiliary input ``a``.

        Raises:
            ValueError: if ``a`` is given but the block was built with
                ``skip_connection=0`` (there is no ``nin_skip`` layer to use it).
        """
        x = self.conv_input(self.nonlinearity(og_x))
        if a is not None:
            if self.skip_connection == 0:
                raise ValueError(
                    "gated_resnet was built with skip_connection=0, "
                    "but an auxiliary input `a` was passed")
            x = x + self.nin_skip(self.nonlinearity(a))
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.conv_out(x)
        update, gate = torch.chunk(x, 2, dim=1)
        return og_x + update * torch.sigmoid(gate)