Skip to content

Commit

Permalink
Keep only one model class (will be parameterized in the future).
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandru-dinu committed Jul 12, 2021
1 parent 64f5232 commit b248583
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 606 deletions.
11 changes: 8 additions & 3 deletions src/models/cae_32x32x32_zero_pad_bin.py → src/cae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ class CAE(nn.Module):
Latent representation: 32x32x32 bits per patch => 240KB per image (for 720p)
"""

# TODO: modularize

def __init__(self):
super(CAE, self).__init__()

self.encoded = None
self.encoded_shape = (32, 32, 32)

# ENCODER

Expand Down Expand Up @@ -165,6 +167,9 @@ def __init__(self):
)

def forward(self, x):
return self.decode(self.encode(x))

def encode(self, x):
ec1 = self.e_conv_1(x)
ec2 = self.e_conv_2(ec1)
eblock1 = self.e_block_1(ec2) + ec2
Expand All @@ -180,10 +185,10 @@ def forward(self, x):
eps[rand <= prob] = (1 - ec3)[rand <= prob]
eps[rand > prob] = (-ec3 - 1)[rand > prob]

# encoded tensor
# encoded tensor (in latent space)
self.encoded = 0.5 * (ec3 + eps + 1) # (-1|1) -> (0|1)

return self.decode(self.encoded)
return self.encoded

def decode(self, encoded):
y = encoded * 2.0 - 1 # (0|1) -> (-1|1)
Expand Down
195 changes: 0 additions & 195 deletions src/models/cae_16x16x16_zero_pad_bin.py

This file was deleted.

Loading

0 comments on commit b248583

Please sign in to comment.