-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcGAN.py
118 lines (100 loc) · 4.8 KB
/
cGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import os
import random
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets
from torchvision.datasets.utils import download_file_from_google_drive
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
#reshape layer
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view(self.shape)
preferred_dataset = 'lfwcrop'
# Generator Class
class Generator(nn.Module):
def __init__(self, n_classes,latentdim, batch_size, dataset_name, img_shape):
super(Generator, self).__init__()
self.label_embed = nn.Embedding(n_classes, n_classes) # mi crea un dizionario di 10 elementi, ogni elemento è a sua volta un vettore di 10 elementi
self.dataset_name=dataset_name
self.img_shape=img_shape
self.depth = 8000 # dimensione output primo layer
def init(input, output, normalize=True):
layers = [nn.Linear(input, output)]
if normalize:
layers.append(nn.BatchNorm1d(output, 0.8)) # do also batch normalization after a layer, then apply leakyRelu as activation function
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.generator = nn.Sequential(
*init(latentdim + n_classes, self.depth),
#*init(self.depth, self.depth * 2),
#*init(self.depth * 2, self.depth * 4),
#*init(self.depth * 4, self.depth * 8),
nn.Linear(self.depth, self.depth),
nn.Sigmoid(),
Reshape(batch_size, 80, 10, 10),
nn.BatchNorm2d(80),
nn.ConvTranspose2d(80, 30, 3, 1, bias=False),
nn.ConvTranspose2d(30, 3, 20, 4, bias=False) # the output will be 64 (divide by 2 kernel and stride to obtain 32)
)
# torchcat needs to combine tensors --> l'embedding delle features sta tutto qui...
def forward(self, noise, labels):
if self.dataset_name!=preferred_dataset:
print("Requested labels", labels.size(), labels)
# in pratica ogni label che noi vogliamo (es: digit 9, digit 3..) fa da chiave nel dizionario label_embed (una hash table) a un vettore di 10 elementi. Questi 10 elementi sono casuali e diversi per ogni label (es: la label 3 sarà una roba tipo [-0.24, 0-7...] con 10 elementi)
# label_embed(labels) avrà quindi 64 (dimensione di un batch che produciamo alla volta, ergo 64 immagini finte) x10 (ogni label richiesta come detto è tradotta in un vettore di 10 elementi)
# il noise ha dimensione 64x100 (64 immagini, un immagine per ogni riga )
# ECCO IL MISTERIOSO EMBEDDING
# al noise vengono aggiunte 10 colonne, che sono le 10 colonne delle label, quindi ogni riga di gen_input è 100 pixel di noise + 10 float che rappresentano la label che vogliamo per quell'immagine. Questo è l'embedding
gen_input = torch.cat((self.label_embed(labels), noise), -1)
print("conditional vector size", self.label_embed(labels).size())
print("Input to generator size",gen_input.size())
else:
'''
tentativo1: concateniamo al noise direttamente il vettore binario coi 40 attributi, senza nessun mapping. labels sarà una matrice binaria
di 64x40
'''
# print("Requested labels", labels.size(), labels)
gen_input = torch.cat((labels, noise), -1)
# print("conditional vector size", labels.size())
# print("Input to generator size",gen_input.size())
img = self.generator(gen_input)
img = img.view(img.size(0), *self.img_shape) # view è un reshape per ottenere dal vettore in output un immagine con le 64 immagini generate dentro
return img
class Discriminator(nn.Module):
def __init__(self, n_classes, latentdim, batch_size, img_shape, dataset_name):
super(Discriminator, self).__init__()
self.label_embed1 = nn.Embedding(n_classes, n_classes)
self.dropout = 0.4
self.dataset_name=dataset_name
self.depth = 512
def init(input, output, normalize=True):
layers = [nn.Linear(input, output)]
if normalize:
layers.append(nn.Dropout(self.dropout))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.discriminator = nn.Sequential(
*init(n_classes + int(np.prod(img_shape)), self.depth, normalize=False),
*init(self.depth, self.depth),
*init(self.depth, self.depth),
nn.Linear(self.depth, 1),
nn.Sigmoid() # classify as true or false
)
def forward(self, img, labels):
imgs = img.view(img.size(0), -1)
if self.dataset_name==preferred_dataset:
inpu = torch.cat((imgs, labels.float()), -1)
else:
inpu = torch.cat((imgs, self.label_embed1(labels)), -1) # associa all'immagine generata (che contiene più cifre da riconoscere) le labels che erano state richieste
validity = self.discriminator(inpu)
return validity