diff --git a/src/models/cae_32x32x32_zero_pad_bin.py b/src/cae.py
similarity index 96%
rename from src/models/cae_32x32x32_zero_pad_bin.py
rename to src/cae.py
index f306677..a080bd9 100644
--- a/src/models/cae_32x32x32_zero_pad_bin.py
+++ b/src/cae.py
@@ -10,10 +10,12 @@ class CAE(nn.Module):
     Latent representation: 32x32x32 bits per patch => 240KB per image (for 720p)
     """
 
+    # TODO: modularize
+
     def __init__(self):
         super(CAE, self).__init__()
 
-        self.encoded = None
+        self.encoded_shape = (32, 32, 32)
 
         # ENCODER
 
@@ -165,6 +167,9 @@ def __init__(self):
         )
 
     def forward(self, x):
+        return self.decode(self.encode(x))
+
+    def encode(self, x):
         ec1 = self.e_conv_1(x)
         ec2 = self.e_conv_2(ec1)
         eblock1 = self.e_block_1(ec2) + ec2
@@ -180,10 +185,10 @@ def forward(self, x):
             eps[rand <= prob] = (1 - ec3)[rand <= prob]
             eps[rand > prob] = (-ec3 - 1)[rand > prob]
 
-        # encoded tensor
+        # encoded tensor (in latent space)
         self.encoded = 0.5 * (ec3 + eps + 1)  # (-1|1) -> (0|1)
 
-        return self.decode(self.encoded)
+        return self.encoded
 
     def decode(self, encoded):
         y = encoded * 2.0 - 1  # (0|1) -> (-1|1)
diff --git a/src/models/cae_16x16x16_zero_pad_bin.py b/src/models/cae_16x16x16_zero_pad_bin.py
deleted file mode 100644
index 61ed935..0000000
--- a/src/models/cae_16x16x16_zero_pad_bin.py
+++ /dev/null
@@ -1,195 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class CAE(nn.Module):
-    """
-    This AE module will be fed 3x128x128 patches from the original image
-    Shapes are (batch_size, channels, height, width)
-
-    Latent representation: 16x16x16 bits per patch => 30KB per image (for 720p)
-    """
-
-    def __init__(self):
-        super(CAE, self).__init__()
-
-        self.encoded = None
-
-        # ENCODER
-
-        # 64x64x64
-        self.e_conv_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_conv_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=64, out_channels=128, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_block_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.e_block_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.e_block_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 16x16x16
-        self.e_conv_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=128, out_channels=16, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-        # DECODER
-
-        # 128x32x32
-        self.d_up_conv_1 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=16, out_channels=64, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 128x32x32
-        self.d_block_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_block_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_block_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 256x64x64
-        self.d_up_conv_2 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=128, out_channels=32, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=32, out_channels=256, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 3x128x128
-        self.d_up_conv_3 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=256, out_channels=16, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=16, out_channels=3, kernel_size=(2, 2), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-    def forward(self, x):
-        ec1 = self.e_conv_1(x)
-        ec2 = self.e_conv_2(ec1)
-        eblock1 = self.e_block_1(ec2) + ec2
-        eblock2 = self.e_block_2(eblock1) + eblock1
-        eblock3 = self.e_block_3(eblock2) + eblock2
-        ec3 = self.e_conv_3(eblock3)  # in [-1, 1] from tanh activation
-
-        # stochastic binarization
-        with torch.no_grad():
-            rand = torch.rand(ec3.shape).cuda()
-            prob = (1 + ec3) / 2
-            eps = torch.zeros(ec3.shape).cuda()
-            eps[rand <= prob] = (1 - ec3)[rand <= prob]
-            eps[rand > prob] = (-ec3 - 1)[rand > prob]
-
-        # encoded tensor
-        self.encoded = 0.5 * (ec3 + eps + 1)  # (-1|1) -> (0|1)
-
-        return self.decode(self.encoded)
-
-    def decode(self, encoded):
-        y = encoded * 2.0 - 1  # (0|1) -> (-1|1)
-
-        uc1 = self.d_up_conv_1(y)
-        dblock1 = self.d_block_1(uc1) + uc1
-        dblock2 = self.d_block_2(dblock1) + dblock1
-        dblock3 = self.d_block_3(dblock2) + dblock2
-        uc2 = self.d_up_conv_2(dblock3)
-        dec = self.d_up_conv_3(uc2)
-
-        return dec
diff --git a/src/models/cae_16x8x8_refl_pad_bin.py b/src/models/cae_16x8x8_refl_pad_bin.py
deleted file mode 100644
index 1f92674..0000000
--- a/src/models/cae_16x8x8_refl_pad_bin.py
+++ /dev/null
@@ -1,204 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-# will see 3x128x128 patches
-class CAE(nn.Module):
-    """
-    This AE module will be fed 3x128x128 patches from the original image
-
-    Latent representation: 16x8x8 bits per patch => 7.5KB per image (for 720p)
-    """
-
-    def __init__(self):
-        super(CAE, self).__init__()
-
-        self.encoded = None
-
-        # ENCODER
-
-        # 64x64x64
-        self.e_conv_1 = nn.Sequential(
-            nn.ReflectionPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_conv_2 = nn.Sequential(
-            nn.ReflectionPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=64, out_channels=128, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_block_1 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x16x16
-        self.e_pool_1 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
-
-        # 128x16x16
-        self.e_block_2 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x16x16
-        self.e_block_3 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 16x8x8
-        self.e_conv_3 = nn.Sequential(
-            nn.ReflectionPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=128, out_channels=16, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-        # DECODER
-
-        # 128x16x16
-        self.d_up_conv_1 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=16, out_channels=64, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 128x16x16
-        self.d_block_1 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_up_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"))
-
-        # 128x32x32
-        self.d_block_2 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_block_3 = nn.Sequential(
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 256x64x64
-        self.d_up_conv_2 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=128, out_channels=32, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=32, out_channels=256, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 3x128x128
-        self.d_up_conv_3 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=256, out_channels=16, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ReflectionPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=16, out_channels=3, kernel_size=(2, 2), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-    def forward(self, x):
-        # ENCODE
-        ec1 = self.e_conv_1(x)
-        ec2 = self.e_conv_2(ec1)
-        eblock1 = self.e_block_1(ec2)
-        eblock1 = self.e_pool_1(ec2 + eblock1)
-        eblock2 = self.e_block_2(eblock1) + eblock1
-        eblock3 = self.e_block_3(eblock2) + eblock2
-        ec3 = self.e_conv_3(eblock3)  # in [-1, 1] from tanh activation
-
-        # stochastic binarization
-        with torch.no_grad():
-            rand = torch.rand(ec3.shape).cuda()
-            prob = (1 + ec3) / 2
-            eps = torch.zeros(ec3.shape).cuda()
-            eps[rand <= prob] = (1 - ec3)[rand <= prob]
-            eps[rand > prob] = (-ec3 - 1)[rand > prob]
-
-        # encoded tensor
-        self.encoded = 0.5 * (ec3 + eps + 1)  # (-1|1) -> (0|1)
-
-        return self.decode(self.encoded)
-
-    def decode(self, enc):
-        y = enc * 2.0 - 1  # (0|1) -> (-1, 1)
-
-        uc1 = self.d_up_conv_1(y)
-        dblock1 = self.d_block_1(uc1) + uc1
-        dup1 = self.d_up_1(dblock1)
-        dblock2 = self.d_block_2(dup1) + dup1
-        dblock3 = self.d_block_3(dblock2) + dblock2
-        uc2 = self.d_up_conv_2(dblock3)
-        dec = self.d_up_conv_3(uc2)
-
-        return dec
diff --git a/src/models/cae_16x8x8_zero_pad_bin.py b/src/models/cae_16x8x8_zero_pad_bin.py
deleted file mode 100644
index 897bdbf..0000000
--- a/src/models/cae_16x8x8_zero_pad_bin.py
+++ /dev/null
@@ -1,204 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class CAE(nn.Module):
-    """
-    This AE module will be fed 3x128x128 patches from the original image
-    Shapes are (batch_size, channels, height, width)
-
-    Latent representation: 16x8x8 bits per patch => 7.5KB per image (for 720p)
-    """
-
-    def __init__(self):
-        super(CAE, self).__init__()
-
-        self.encoded = None
-
-        # ENCODER
-
-        # 64x64x64
-        self.e_conv_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_conv_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=64, out_channels=128, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.LeakyReLU(),
-        )
-
-        # 128x32x32
-        self.e_block_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x16x16
-        self.e_pool_1 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
-
-        # 128x16x16
-        self.e_block_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x16x16
-        self.e_block_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 16x8x8
-        self.e_conv_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 2, 1, 2)),
-            nn.Conv2d(
-                in_channels=128, out_channels=16, kernel_size=(5, 5), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-        # DECODER
-
-        # 128x16x16
-        self.d_up_conv_1 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=16, out_channels=64, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 128x16x16
-        self.d_block_1 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_up_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"))
-
-        # 128x32x32
-        self.d_block_2 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 128x32x32
-        self.d_block_3 = nn.Sequential(
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.Conv2d(
-                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)
-            ),
-        )
-
-        # 256x64x64
-        self.d_up_conv_2 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=128, out_channels=32, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=32, out_channels=256, kernel_size=(2, 2), stride=(2, 2)
-            ),
-        )
-
-        # 3x128x128
-        self.d_up_conv_3 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=256, out_channels=16, kernel_size=(3, 3), stride=(1, 1)
-            ),
-            nn.LeakyReLU(),
-            nn.ZeroPad2d((1, 1, 1, 1)),
-            nn.ConvTranspose2d(
-                in_channels=16, out_channels=3, kernel_size=(2, 2), stride=(2, 2)
-            ),
-            nn.Tanh(),
-        )
-
-    def forward(self, x):
-        # ENCODE
-        ec1 = self.e_conv_1(x)
-        ec2 = self.e_conv_2(ec1)
-        eblock1 = self.e_block_1(ec2)
-        eblock1 = self.e_pool_1(ec2 + eblock1)
-        eblock2 = self.e_block_2(eblock1) + eblock1
-        eblock3 = self.e_block_3(eblock2) + eblock2
-        ec3 = self.e_conv_3(eblock3)  # in [-1, 1] from tanh activation
-
-        # stochastic binarization
-        with torch.no_grad():
-            rand = torch.rand(ec3.shape).cuda()
-            prob = (1 + ec3) / 2
-            eps = torch.zeros(ec3.shape).cuda()
-            eps[rand <= prob] = (1 - ec3)[rand <= prob]
-            eps[rand > prob] = (-ec3 - 1)[rand > prob]
-
-        # encoded tensor
-        self.encoded = 0.5 * (ec3 + eps + 1)  # (-1|1) -> (0|1)
-
-        return self.decode(self.encoded)
-
-    def decode(self, enc):
-        y = enc * 2.0 - 1  # (0|1) -> (-1, 1)
-
-        uc1 = self.d_up_conv_1(y)
-        dblock1 = self.d_block_1(uc1) + uc1
-        dup1 = self.d_up_1(dblock1)
-        dblock2 = self.d_block_2(dup1) + dup1
-        dblock3 = self.d_block_3(dblock2) + dblock2
-        uc2 = self.d_up_conv_2(dblock3)
-        dec = self.d_up_conv_3(uc2)
-
-        return dec
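
Usage note for reviewers (not part of the patch): with forward() split into encode() and decode(), the latent bits can now be produced and stored without running the decoder. Below is a minimal, hypothetical round-trip sketch under stated assumptions: it uses only the renamed src.cae module from this diff; the byte-packing step and the printed size check are illustrative, added to make the docstring's "240KB per image" arithmetic concrete; a CUDA device is assumed because encode() allocates its random tensors with .cuda().

import numpy as np
import torch

from src.cae import CAE  # module path after the rename in this diff

device = "cuda"  # encode() calls .cuda() internally, so a GPU is assumed
model = CAE().to(device).eval()

# One 3x128x128 patch, laid out (batch, channels, height, width).
patch = torch.rand(1, 3, 128, 128, device=device)

with torch.no_grad():
    bits = model.encode(patch)   # (1, 32, 32, 32), entries in {0.0, 1.0}
    recon = model.decode(bits)   # (1, 3, 128, 128), in [-1, 1] from tanh

assert bits.shape[1:] == model.encoded_shape

# Size check against the docstring: pack the 0/1 floats into real bits.
packed = np.packbits(bits.cpu().numpy().astype(np.uint8))
print(packed.nbytes)  # 32*32*32 / 8 = 4096 bytes = 4 KB per patch

# 720p is 1280x720; padding the height to 768 gives 10 x 6 = 60 patches,
# and 60 * 4 KB = 240 KB, matching the "240KB per image" figure.

One caveat worth noting: the binarization samples a random tensor on every call, so encode() is stochastic and can return different bits for the same patch across runs.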