only calculate self-supervised auxiliary loss when updating discriminator
lucidrains committed Nov 16, 2020
1 parent 790167a commit e1a14c3
Showing 2 changed files with 11 additions and 8 deletions.
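
At its core, the change adds a calc_aux_loss flag to the discriminator's forward pass and returns early, so the SimpleDecoder reconstruction branch only runs when the caller asks for it. A minimal sketch of that gating pattern follows; the toy trunk, decoder, and reconstruction loss below are placeholders, not the repository's actual architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAuxDiscriminator(nn.Module):
    # placeholder discriminator: conv trunk, logits head, and one auxiliary decoder
    def __init__(self, chan = 64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, chan, 4, stride = 2, padding = 1), nn.LeakyReLU(0.1),
            nn.Conv2d(chan, chan, 4, stride = 2, padding = 1), nn.LeakyReLU(0.1)
        )
        self.to_logits = nn.Conv2d(chan, 1, 4)
        # decoder used only for the self-supervised reconstruction loss
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor = 4),
            nn.Conv2d(chan, 3, 3, padding = 1)
        )

    def forward(self, x, calc_aux_loss = False):
        orig_img = x
        feats = self.trunk(x)
        out = self.to_logits(feats).flatten(1)

        # the early return is the point of the commit: skip the decoder
        # (and its loss) entirely unless the caller asks for it
        if not calc_aux_loss:
            return out, None

        recon = self.decoder(feats)
        target = F.interpolate(orig_img, size = recon.shape[-2:])
        aux_loss = F.mse_loss(recon, target)
        return out, aux_loss

# usage: the second return value is a scalar loss tensor only when requested
D = TinyAuxDiscriminator()
images = torch.randn(4, 3, 32, 32)
logits, aux = D(images, calc_aux_loss = True)
logits, no_aux = D(images)   # no_aux is None; decoder never ran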
17 changes: 10 additions & 7 deletions lightweight_gan/lightweight_gan.py
@@ -204,15 +204,15 @@ def __init__(self, D, image_size):
         super().__init__()
         self.D = D

-    def forward(self, images, prob = 0., types = [], detach = False):
+    def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
         if random() < prob:
             images = random_hflip(images, prob=0.5)
             images = DiffAugment(images, types=types)

         if detach:
             images = images.detach()

-        return self.D(images)
+        return self.D(images, **kwargs)

 # classes
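The **kwargs passthrough above is what carries the new flag: AugWrapper applies augmentation (and optionally detaches), then hands any extra keyword arguments, such as calc_aux_loss = True, straight to the wrapped discriminator, whose updated signature appears in the hunks below.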

@@ -413,7 +413,7 @@ def __init__(
         self.decoder1 = SimpleDecoder(chan_in = last_chan, chan_out = init_channel)
         self.decoder2 = SimpleDecoder(chan_in = features[-2][-1], chan_out = init_channel)

-    def forward(self, x):
+    def forward(self, x, calc_aux_loss = False):
         orig_img = x

         for layer in self.non_residual_layers:
@@ -425,7 +425,10 @@ def forward(self, x):
             x = layer(x) + residual_layer(x)
             layer_outputs.append(x)

-        out = self.to_logits(x)
+        out = self.to_logits(x).flatten(1)
+
+        if not calc_aux_loss:
+            return out, None

         # self-supervised auto-encoding loss

@@ -452,7 +455,7 @@ def forward(self, x):

         aux_loss = aux_loss1 + aux_loss2

-        return out.flatten(1), aux_loss
+        return out, aux_loss

 class LightweightGAN(nn.Module):
     def __init__(
@@ -673,11 +676,11 @@ def train(self):
             latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

             generated_images = G(latents)
-            fake_output, fake_aux_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)
+            fake_output, fake_aux_loss = D_aug(generated_images.clone().detach(), detach = True, calc_aux_loss = True, **aug_kwargs)

             image_batch = next(self.loader).cuda(self.rank)
             image_batch.requires_grad_()
-            real_output, real_aux_loss = D_aug(image_batch, **aug_kwargs)
+            real_output, real_aux_loss = D_aug(image_batch, calc_aux_loss = True, **aug_kwargs)

             real_output_loss = real_output
             fake_output_loss = fake_output
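
Put together with the training-loop change above, only the discriminator update requests the auxiliary loss; per the commit message, the generator update leaves calc_aux_loss at its False default, so the decoders never run there. A rough sketch of one training step under that reading, where the hinge losses, aux-loss weighting, and optimizer handling are simplified placeholders rather than the repository's exact code:

import torch
import torch.nn.functional as F

def train_step(G, D_aug, G_opt, D_opt, image_batch, latent_dim = 256):
    batch_size = image_batch.shape[0]
    device = image_batch.device

    # --- discriminator update: request the auxiliary self-supervised loss ---
    D_opt.zero_grad()
    latents = torch.randn(batch_size, latent_dim, device = device)
    generated_images = G(latents)
    fake_output, fake_aux_loss = D_aug(generated_images.clone().detach(), detach = True, calc_aux_loss = True)
    real_output, real_aux_loss = D_aug(image_batch, calc_aux_loss = True)

    # hinge-style divergence; how the aux terms are weighted is the repo's choice
    disc_loss = F.relu(1 - real_output).mean() + F.relu(1 + fake_output).mean()
    disc_loss = disc_loss + real_aux_loss + fake_aux_loss
    disc_loss.backward()
    D_opt.step()

    # --- generator update: calc_aux_loss stays False, decoders are skipped ---
    G_opt.zero_grad()
    latents = torch.randn(batch_size, latent_dim, device = device)
    fake_output, _ = D_aug(G(latents))    # second return value is None here
    gen_loss = (-fake_output).mean()      # sign convention is a placeholder
    gen_loss.backward()
    G_opt.step()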
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '0.1.3'
+__version__ = '0.1.4'
