-
Notifications
You must be signed in to change notification settings - Fork 0
/
generator_model.py
110 lines (82 loc) · 3.31 KB
/
generator_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Generator model for CycleGAN
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-05: Initial coding
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
"""
from common_blocks import ConvBlock, ResidualBlock, TransposeConvBlock
from torch import nn
import torch
class GeneratorEncoder(nn.Module):
    """Encoder half of the CycleGAN generator.

    Three ConvBlocks widen the channel count from ``in_channels`` to
    ``out_channels``, then ``2 * out_channels``, then ``4 * out_channels``.
    Two ResidualBlocks follow, operating at ``4 * out_channels`` channels.
    Every ConvBlock is configured with kernel_size=3, padding=1, stride=1 and
    ``is_relu=True``.

    Args:
        in_channels: Number of channels of the input tensor (default 3).
        out_channels: Base feature width; the encoder output carries
            ``4 * out_channels`` channels (default 64).
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        base = out_channels
        # Channel widths along the conv stack: in -> base -> 2*base -> 4*base.
        widths = [in_channels, base, base * 2, base * 4]
        conv_kwargs = dict(kernel_size=3, padding=1, stride=1, is_relu=True)
        # Build the layers in order (conv, conv, conv, res, res). Construction
        # order decides Sequential indices and thus state_dict keys.
        stages = [
            ConvBlock(c_in, c_out, **conv_kwargs)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # The residual blocks use ResidualBlock's default padding mode.
        # GeneratorDecoder instead passes padding_mode="zeros" explicitly.
        stages.extend(ResidualBlock(widths[-1]) for _ in range(2))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the feature map."""
        return self.encoder(x)
class GeneratorDecoder(nn.Module):
    """Decoder half of the CycleGAN generator.

    Layer sequence:
    - two ResidualBlocks at ``4 * out_channels`` channels;
    - three TransposeConvBlocks narrowing the channels
      ``4*out_channels -> 2*out_channels -> out_channels -> in_channels``.

    Every block is built with ``padding_mode="zeros"``. Each
    TransposeConvBlock also uses kernel_size=3, padding=1, stride=1 and
    ``is_relu=True``.

    The argument names follow GeneratorEncoder's convention, so they are
    inverted from the decoder's point of view:
    - ``in_channels`` is the channel count the decoder *produces*;
    - the decoder *consumes* a tensor with ``4 * out_channels`` channels.

    NOTE(review): the last TransposeConvBlock also has ``is_relu=True``. If
    that activation is a ReLU, the generator output can never be negative.
    Confirm this matches the image normalisation used during training.

    Args:
        in_channels: Channels of the decoded output (default 3).
        out_channels: Base feature width; input has ``4 * out_channels``
            channels (default 64).
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        wide = out_channels * 4
        mid = out_channels * 2
        up_kwargs = dict(
            kernel_size=3, padding=1, stride=1, is_relu=True, padding_mode="zeros"
        )
        # Build the layers in order (res, res, up, up, up). Construction order
        # decides Sequential indices and thus state_dict keys.
        stages = [ResidualBlock(wide, padding_mode="zeros") for _ in range(2)]
        stages.append(TransposeConvBlock(wide, mid, **up_kwargs))
        stages.append(TransposeConvBlock(mid, out_channels, **up_kwargs))
        stages.append(
            TransposeConvBlock(out_channels, out_channels=in_channels, **up_kwargs)
        )
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode the feature map ``x`` back to ``in_channels`` channels."""
        return self.decoder(x)
# Tests
def test_generator_encoder():
    """Smoke-test GeneratorEncoder by printing input and output shapes.

    Feeds a random batch of five 3-channel 128x128 tensors through an encoder
    with in_channels=3 and out_channels=64. It only prints; nothing is
    asserted.
    """
    batch, channels, size = 5, 3, 128
    sample = torch.randn((batch, channels, size, size))
    encoder = GeneratorEncoder(in_channels=channels, out_channels=64)
    print("input : ", sample.shape)
    encoded = encoder(sample)
    print("encoder of generator output : ", encoded.shape)
def test_generator_decoder():
    """Smoke-test GeneratorDecoder by printing input and output shapes.

    The decoder consumes a tensor with ``4 * out_channels`` channels. The
    random input's channel count is therefore derived from the same
    ``num_features`` value handed to the decoder, instead of a hard-coded 256
    that silently depends on the decoder's default width. It only prints;
    nothing is asserted.
    """
    images_per_batch = 5
    img_size = 128
    num_features = 64  # decoder base width (matches GeneratorDecoder's default)
    decoder_in_channels = num_features * 4  # = 256, the decoder's input width
    x = torch.randn((images_per_batch, decoder_in_channels, img_size, img_size))
    gen_decoder = GeneratorDecoder(out_channels=num_features)
    print("input : ", x.shape)
    print("decoder of generator output : ", gen_decoder(x).shape)
class Generator(nn.Module):
    """Full CycleGAN generator: a GeneratorEncoder followed by a GeneratorDecoder.

    Both halves receive the same ``in_channels`` / ``out_channels``. The
    decoder therefore consumes the ``4 * out_channels``-channel encoder output
    and produces a tensor with ``in_channels`` channels.

    Args:
        in_channels: Channels of the input image, and of the generated output.
        out_channels: Base feature width shared by encoder and decoder.
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        widths = dict(in_channels=in_channels, out_channels=out_channels)
        # Keep the construction order: encoder first, then decoder. Building
        # a submodule may draw from the global RNG for weight initialisation,
        # so reordering could change the initial weights.
        self.encoder = GeneratorEncoder(**widths)
        self.decoder = GeneratorDecoder(**widths)

    def forward(self, x):
        """Encode ``x`` and decode the resulting features."""
        return self.decoder(self.encoder(x))

    def debug(self, x):
        """Run the same computation as ``forward``, printing each stage's shape.

        Returns:
            The decoder output, i.e. the same tensor ``forward`` computes.
        """
        print("Input : ", x.shape)
        features = self.encoder(x)
        print("After encoder : ", features.shape)
        decoded = self.decoder(features)
        print("After decoder : ", decoded.shape)
        return decoded
def test_generator():
    """Smoke-test the full Generator on a non-default configuration.

    Builds a generator with 4 image channels and a base width of 50, then
    pushes a random batch through it. ``Generator.debug`` prints the
    per-stage shapes and returns the output. A single pass is therefore
    enough; a separate ``gen(x)`` call would run the whole network twice.
    It only prints; nothing is asserted.
    """
    images_per_batch = 5
    img_channels = 4
    img_size = 128
    # Base feature width of the generator. The encoder output has
    # num_features * 4 channels, not num_features.
    num_features = 50
    x = torch.randn((images_per_batch, img_channels, img_size, img_size))
    gen = Generator(in_channels=img_channels, out_channels=num_features)
    output = gen.debug(x)  # prints Input / After encoder / After decoder shapes
    print("generator output :", output.shape)
if __name__ == "__main__":
    # Run the smoke tests in order: encoder, decoder, then the full generator.
    for smoke_test in (test_generator_encoder, test_generator_decoder, test_generator):
        smoke_test()