Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

An example of use aw_loss in cDCGAN. #6

Open
qingyuany opened this issue Oct 24, 2023 · 0 comments
Open

An example of use aw_loss in cDCGAN. #6

qingyuany opened this issue Oct 24, 2023 · 0 comments

Comments

@qingyuany
Copy link

I am trying to use aw loss in my cDCGAN. Unexpected issues occurred during code compilation. I don't quite understand. I hope to receive your help. Alternatively, upload an application instance of 'awloss'.
My code is as follows:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dset
from tqdm.autonotebook import tqdm
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import numpy as np
from torch.autograd import Variable
from new_loss import aw_method

num_epochs = 1000
betas = (0.5, 0.999)
lr = 0.0002 # 1e-5

batch_size = 64
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
z_dim = 100 # latent Space
c_dim = 1 # Image Channel
label_dim = 10 # label
image_size = 32
beta1 = 0.5
PATH = "./generate/"

generator_out_linear = 100 # l rumore che andrà al generatore sarà 100+generator_out_linear, deve essere >=10

MNIST dataset

transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
# transforms.Normalize((0.5,),(0.5,)),
])

train_set = dset.MNIST(root='./mnist_data/',
train=True,
transform=transform,
download=True)

test_set = dset.MNIST(root='./mnist_data/',
train=False,
transform=transform,
download=False)

train_loader = torch.utils.data.DataLoader(
dataset=train_set,
batch_size=batch_size,
shuffle=True,
drop_last=True
)

test_loader = torch.utils.data.DataLoader(
dataset=test_set,
batch_size=batch_size,
shuffle=False,
drop_last=True
)

Generator model

class Generator(nn.Module):
def init(self, z_dim, label_dim):
super(Generator, self).init()

    self.ylabel = nn.Sequential(
        nn.Linear(10, generator_out_linear),  # 将10通道onehot标签通过线性变换成100通道
        nn.ReLU(True)
    )

    self.concat = nn.Sequential(
        # 噪声将变为 z_dim + 条件噪声, 100+100通道
        nn.ConvTranspose2d(z_dim + generator_out_linear, 64 * 4, 4, 1, 0, bias=False),
        nn.BatchNorm2d(64 * 4),
        nn.ReLU(True),

        nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(64 * 2),
        nn.ReLU(True),

        nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),

        nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
        nn.Tanh()
    )

def forward(self, x, y):
    y = y.reshape(-1, 10)
    y = self.ylabel(y)
    y = y.reshape(-1, generator_out_linear, 1, 1)  # 增加维度,与输入噪声维度达成一致

    out = torch.cat([x, y], dim=1)  # 将噪声与标签在通道维度上融合在一起
    out = out.view(-1, 100 + generator_out_linear, 1, 1)

    out = self.concat(out)  # 获得生成器网络输出数据

    return out

Discriminator model

class Discriminator(nn.Module):
def init(self, nc=1, label_dim=10):
super(Discriminator, self).init()

    self.ylabel = nn.Sequential(
        nn.Linear(10, 32 * 32 * 1),  # 将one-hot编码标签转换成与图片形式相同的维度 32*32*1  1表示通道数
        nn.ReLU(True)
    )

    self.concate = nn.Sequential(
        # 将图片与标签在通道维度上拼接在一起  输入结构是 (-1,1+1,32,32)
        nn.Conv2d(nc + 1, 64, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(64 * 2),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(64 * 4),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(64 * 4, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
    )

def forward(self, x, y):
    y = y.reshape(batch_size, 10)
    y = self.ylabel(y)
    y = y.view(-1, 1, image_size, image_size)

    out = torch.cat([x, y], dim=1)
    out = self.concate(out)

    return out

def weights_init(m):
# 用于初始化神经网络模型的权重和偏差
classname = m.class.name
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02) # 初始化卷积层,将权重参数的值从均值0.0、标准差0.02的正态分布中随机抽取
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)

创建 aw_method 实例,可以根据需要自定义超参数

aw = aw_method(alpha1=0.5, alpha2=0.75, delta=0.05, epsilon=0.05, normalized_aw=True)

def train_GAN(G, D, G_opt, D_opt, dataset):
for i, (data, label) in tqdm(enumerate(dataset)):
'''
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(data, padding=2).cpu(),(1,2,0)))
plt.show()
'''

    # Train with all-real batch
    D_opt.zero_grad()

    x_real = data.to(device)
    y_real = torch.ones(batch_size, ).to(device)
    label_onehot = onehot[label]
    # y_real_predict = D(x_real, label_onehot).squeeze()  # (-1, 1, 1, 1) -> (-1, )  squeeze() 用于去除维度为1的通道

    # d_real_loss = criterion(y_real_predict, y_real)
    # d_real_loss.backward()
    real_validity = D(x_real, label_onehot).squeeze()
    Dloss_real = criterion(real_validity, y_real)

    # Train with all-fake batch
    noise = torch.randn(batch_size, z_dim, 1, 1, device=device)  # 因为其他维度大小都是1,所以在通道维度拼接,其实就相当于延长数据长度
    noise_label = (torch.rand(batch_size, 1) * label_dim).type(torch.LongTensor).squeeze()  # 生成0-9之内的随机整数
    # print(noise_label)
    noise_label_onehot = onehot[noise_label].to(device)  # Genera label in modo casuale (-1,)
    x_fake = G(noise, noise_label_onehot)  # Genera immagini false
    y_fake = torch.zeros(batch_size, ).to(device)  # Assegna label 0
    # y_fake_predict = D(x_fake, noise_label_onehot).squeeze()
    # d_fake_loss = criterion(y_fake_predict, y_fake)
    # d_fake_loss.backward()
    fake_validity = D(x_fake, noise_label_onehot).squeeze()
    Dloss_fake = criterion(fake_validity, y_fake)

    # 计算 aw_loss
    aw_loss = aw.aw_loss(Dloss_real, Dloss_fake, D_opt, D, real_validity, fake_validity)

    # 反向传播和优化
    aw_loss.backward(retain_graph=True)
    # D_opt.step()



    # (2) Update G network: maximize log(D(G(z)))
    G_opt.zero_grad()

    noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
    noise_label = (torch.rand(batch_size, 1) * label_dim).type(torch.LongTensor).squeeze()
    noise_label_onehot = onehot[noise_label].to(device)  # Genera label in modo casuale (-1,)
    x_fake = G(noise, noise_label_onehot)
    # y_fake = torch.ones(batch_size, ).to(device)    # Il y_fake qui è lo stesso di y_real sopra, entrambi sono 1
    # y_fake_predict = D(x_fake, noise_label_onehot).squeeze()
    # g_loss = criterion(y_fake_predict, y_real)  # Usa direttamente y_real per essere più intuitivo

    fake_validity = D(x_fake, noise_label_onehot).squeeze()
    g_loss = criterion(fake_validity, y_real)

    # 反向传播和优化
    g_loss.backward()
    G_opt.step()

    err_D = aw_loss.item()
    err_G = g_loss.item()
    '''
    if i%50 == 0:
        with torch.no_grad():
            out_imgs = G(fixed_noise.to(device), fixed_label.to(device))
        save_image(out_imgs,f"{PATH}{i}.png", nrow = 10) #aggiungi percorso: "path/iterazione_classe.png" es "pippo/20000_3.png"
    '''
return err_D, err_G

Models

D = Discriminator(c_dim, label_dim).to(device)
D.apply(weights_init)

G = Generator(z_dim, label_dim).to(device)
G.apply(weights_init)

D_opt = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999)) # , betas=(beta1, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999)) # , betas=(beta1, 0.999))

Loss function

criterion = torch.nn.BCELoss()

创建一个固定噪声,用于测试

fixed_noise = torch.randn(100, 100)
fixed_noise = fixed_noise.reshape(100, 100, 1, 1)

创建一个固定标签,

labels = torch.LongTensor([i for i in range(10) for _ in range(
10)]).cuda()

labels = 00000000001111111111222222222233333333334444444444555555555566666666667777777777788888888889999999999

labels = labels.reshape(100, 1)
one_hot = nn.functional.one_hot(labels, num_classes=10) # i labels codificato in one_hot
fixed_label = one_hot.reshape(100, 10, 1, 1).float()

我在 onehot 中为数字 0 到 9 创建了自己的转换器:

onehot_before_cod = torch.LongTensor([i for i in range(10)]).cuda() # 0123456789
onehot = nn.functional.one_hot(onehot_before_cod, num_classes=10)
onehot = onehot.reshape(10, 10, 1, 1).float()

D_loss = []
G_loss = []

for epoch in tqdm(range(num_epochs)):
D_losses = []
G_losses = []
if epoch == 5 or epoch == 10:
G_opt.param_groups[0]['lr'] /= 2
D_opt.param_groups[0]['lr'] /= 2

# training
err_D, err_G = train_GAN(G, D, G_opt, D_opt, train_loader)

D_loss.append(err_D)
G_loss.append(err_G)

# test
if epoch % 1 == 0 or epoch + 1 == num_epochs:
    with torch.no_grad():
        out_imgs = G(fixed_noise.to(device), fixed_label.to(device))

    save_image(out_imgs, f"{PATH}{epoch}.png",
               nrow=10)  # aggiungi percorso: "path/iterazione_classe.png" es "pippo/20000_3.png"

    # salva i modelli
    torch.save(D.state_dict(), f'{PATH}discriminator_cDCGAN_{epoch}.pth')
    torch.save(G.state_dict(), f'{PATH}generator_cDCGAN_{epoch}.pth')

Unexpected issues occurred:
Traceback (most recent call last):
File "D:\python_parctice\pt\个人练习\数据预处理模板\不平衡图像数据的条件生成对抗网络生成\cDCGAN-main\cDCGAN MNIST.py", line 267, in
err_D, err_G = train_GAN(G, D, G_opt, D_opt, train_loader)
File "D:\python_parctice\pt\个人练习\数据预处理模板\不平衡图像数据的条件生成对抗网络生成\cDCGAN-main\cDCGAN MNIST.py", line 192, in train_GAN
aw_loss.backward(retain_graph=True)
File "D:\Anaconda\install\envs\pt\lib\site-packages\torch_tensor.py", line 487, in backward
torch.autograd.backward(
File "D:\Anaconda\install\envs\pt\lib\site-packages\torch\autograd_init_.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant