Skip to content

Commit

Permalink
Revert "Use LeakyReLU in descriminator feature extraction"
Browse files Browse the repository at this point in the history
This reverts commit 6cd17d8.
  • Loading branch information
moto-hellomoto-ai committed May 6, 2019
1 parent 852ea69 commit dd01034
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions sp_vae_gan/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Sequential):
Expand Down Expand Up @@ -91,32 +92,24 @@ def __init__(self, feat_size):
self.convs = nn.Sequential(
nn.ReflectionPad2d(2),
nn.Conv2d(3, 32, kernel_size=5),
nn.LeakyReLU(),
#################
nn.ReflectionPad2d(2),
nn.Conv2d(32, 128, kernel_size=5, stride=2),
nn.LeakyReLU(),
#################
nn.ReflectionPad2d(2),
nn.Conv2d(128, 256, kernel_size=5, stride=2),
nn.LeakyReLU(),
#################
nn.ReflectionPad2d(2),
nn.Conv2d(256, 256, kernel_size=5, stride=2),
nn.LeakyReLU(),
#################
)
n_feat = self.feat_size[0] * self.feat_size[1] * 256
self.fc = nn.Sequential(
nn.Linear(in_features=n_feat, out_features=512),
nn.LeakyReLU(),
nn.ReLU(inplace=True),
nn.Linear(in_features=512, out_features=1),
nn.Sigmoid()
)

def forward(self, x):
x = self.convs(x)
x_feats = x
x_feats = self.convs(x)
x = F.relu(x_feats)
x = x.view(len(x), -1)
x = self.fc(x)
return x, x_feats
Expand Down

0 comments on commit dd01034

Please sign in to comment.