ganLoss.py
import torch
import torch.nn as nn
import torch.autograd as autograd
LAMBDA = 0.1     # default gradient-penalty weight
BATCH_SIZE = 12  # used only by the commented-out variant below
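# Earlier variant of the penalty, kept commented out for reference: it drew a
# per-sample alpha of shape (BATCH_SIZE, 1) and used the module-level LAMBDA.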
# def calc_gradient_penalty(netD, real_data, fake_data):
# alpha = torch.rand(BATCH_SIZE, 1)
# alpha = alpha.expand(real_data.size())
# alpha = alpha.cuda()
#
# interpolates = alpha * real_data + ((1 - alpha) * fake_data)
#
# interpolates = interpolates.cuda()
# interpolates = autograd.Variable(interpolates, requires_grad=True)
#
# disc_interpolates = netD(interpolates)
#
# gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
# grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
# create_graph=True, retain_graph=True, only_inputs=True)[0]
#
# gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
#
# return gradient_penalty
def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017):
    LAMBDA * E[(||grad_x netD(x_hat)||_2 - 1)^2], where x_hat is a random
    interpolation of real and fake samples. Assumes CUDA is available, as
    in the original code."""
    MSGGan = False
    if MSGGan:
        # MSG-GAN case: real_data / fake_data are lists of tensors at
        # multiple scales, all mixed with one shared alpha.
        alpha = torch.rand(1, 1).cuda()
        interpolates = [(alpha * rd + (1 - alpha) * fd).cuda()
                        .detach().requires_grad_(True)
                        for rd, fd in zip(real_data, fake_data)]
        disc_interpolates = netD(interpolates)
    else:
        # One random mixing coefficient, broadcast over the whole batch.
        alpha = torch.rand(1, 1).expand(real_data.size()).cuda()
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).cuda()
        # Re-leaf the interpolates so gradients are taken w.r.t. them
        # (replaces the deprecated autograd.Variable wrapper).
        interpolates = interpolates.detach().requires_grad_(True)
        disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]
    # Two-sided penalty: push the gradient norm toward 1. The norm is taken
    # over dim=1 only, matching the original code; for image tensors one
    # would usually flatten to (batch, -1) first.
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
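# A minimal usage sketch for the critic step of WGAN-GP training. Everything
# below is illustrative only: netD, netG, optD, real_batch, and noise are
# hypothetical stand-ins, not defined in this file.
#
#     optD.zero_grad()
#     fake_batch = netG(noise).detach()  # detach: no generator update here
#     d_loss = (netD(fake_batch).mean() - netD(real_batch).mean()
#               + calc_gradient_penalty(netD, real_batch, fake_batch, LAMBDA))
#     d_loss.backward()
#     optD.step()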