-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgan.py
124 lines (100 loc) · 4.87 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import datetime
import random

import torch
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from src import preprocessing
from src.data import MontevideoFoldersDataset
from src.dl_models.gan import Discriminator
from src.dl_models.unet import UNet2
from src.lib.utils import gradient_penalty
from src.lib.utils import save_gan_checkpoint
# TODO: move to package
# Parameters and hyperparameters for the WGAN-GP fine-tuning run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(50)  # fixed seed for reproducibility
PT_PATH = '/clusteruy/home03/DeepCloud/deepCloud/checkpoints/10min_UNet2_sigmoid_mae_f32_60_04-08-2021_20:43.pt'
CSV_PATH = '/clusteruy/home03/DeepCloud/deepCloud/data/mvd/train_cosangs_in3_out1.csv'
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
NUM_EPOCHS = 10
LAMBDA_GP = 5          # weight of the gradient-penalty term in the critic loss
CRITIC_ITERATIONS = 5  # generator is updated once every CRITIC_ITERATIONS batches
FEATURES_D = 32        # base feature width of the discriminator
# Dataloaders
normalize = preprocessing.normalize_pixels()
dataset_kwargs = dict(
    path='/clusteruy/home03/DeepCloud/deepCloud/data/mvd/train/',
    in_channel=3,
    out_channel=1,
    min_time_diff=5,
    max_time_diff=15,
    csv_path=CSV_PATH,
    transform=normalize,
)
train_ds = MontevideoFoldersDataset(**dataset_kwargs)
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
# Nets
# Generator: UNet2 mapping 3 input frames -> 1 predicted frame.
gen = UNet2(n_channels=3, n_classes=1, bilinear=True, filters=32).to(device)
# Critic/discriminator over single-channel images.
disc = Discriminator(channels_img=1, features_d=FEATURES_D).to(device)
# Warm-start the generator from a supervised checkpoint
# (presumably MAE-trained, per the checkpoint filename — confirm).
gen.load_state_dict(torch.load(PT_PATH)["model_state_dict"])
# Initializate optimizer — RMSprop chosen; Adam variants kept below for reference.
#opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
#opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_disc = optim.RMSprop(disc.parameters(), lr=LEARNING_RATE)
# Both nets in training mode (enables dropout/batch-norm updates if present).
gen.train()
disc.train()
# Description of the experiment: a timestamp plus a hyperparameter summary,
# used to name the TensorBoard run directories and the saved checkpoint.
# Fix: `datetime` was used here without being imported (NameError at runtime);
# the import is now at the top of the file.
ts = datetime.datetime.now().strftime("%d-%m-%Y_%H:%M")
exp_desc = f'lr({LEARNING_RATE})_opt(rmsprop)_lambda_gp({LAMBDA_GP})_load_dict(10min_UNet2_sigmoid_mae_f32_60_04-08-2021_20:43.pt)_features_d({FEATURES_D})_csv(train_cosangs_in3_out1)'
# TensorBoard writers: ground-truth images, predicted images, scalar losses.
writer_gt = SummaryWriter(f"runs/{ts}/{exp_desc}/gt")
writer_pred = SummaryWriter(f"runs/{ts}/{exp_desc}/pred")
writer = SummaryWriter(f"runs/{ts}/{exp_desc}/loss")
step = 0  # global TensorBoard step, incremented on each logging event
gen_loss_by_epochs = []   # mean generator loss per epoch
disc_loss_by_epochs = []  # mean critic loss per epoch
# WGAN-GP training loop: the critic is updated on every batch; the generator
# is updated once every CRITIC_ITERATIONS batches (indentation restored here —
# it was stripped in the pasted source).
for epoch in range(NUM_EPOCHS):
    gen_epoch_loss_list = []   # per-batch generator losses this epoch
    disc_epoch_loss_list = []  # per-batch critic losses this epoch
    for batch_idx, (in_frames, gt) in enumerate(train_loader):
        in_frames = in_frames.to(device)
        gt = gt.to(device)
        # Train Critic: max E[critic(real)] - E[critic(fake)]
        pred = gen(in_frames)
        disc_pred = disc(pred).reshape(-1)
        disc_gt = disc(gt).reshape(-1)
        gp = gradient_penalty(disc, gt, pred, device=device)
        loss_disc = (
            -(torch.mean(disc_gt) - torch.mean(disc_pred)) + LAMBDA_GP * gp
        )
        disc.zero_grad()
        # retain_graph=True keeps the graph through `pred` alive so the
        # generator update below can backprop through the same forward pass.
        loss_disc.backward(retain_graph=True)
        opt_disc.step()
        # Train Generator: max E[disc(gen_noise)] <-> min -E[disc(gen_noise)]
        if batch_idx % CRITIC_ITERATIONS == 0:
            # Re-score `pred` with the just-updated critic.
            disc_pred = disc(pred).reshape(-1)
            loss_gen = -torch.mean(disc_pred)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()
        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch+1}/{NUM_EPOCHS}] Batch {batch_idx}/{len(train_loader)} \
                  Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
            )
            #print(f'{torch.mean(disc_gt)}, {-torch.mean(disc_pred)}, {LAMBDA_GP * gp}')
            with torch.no_grad():
                writer.add_scalar('Gen Loss', loss_gen, global_step=step)
                writer.add_scalar('Disc Loss', loss_disc, global_step=step)
                # print images to tb, disabled
                # NOTE(review): the original comment says "disabled" but the
                # image-logging calls below are active — confirm intent.
                img_grid_gt = torchvision.utils.make_grid(gt, normalize=True)
                img_grid_pred = torchvision.utils.make_grid(pred, normalize=True)
                writer_gt.add_image("gt", img_grid_gt, global_step=step)
                writer_pred.add_image("pred", img_grid_pred, global_step=step)
            step += 1
        # NOTE(review): loss_gen is appended every batch but only refreshed
        # every CRITIC_ITERATIONS batches, so stale values repeat in the mean.
        gen_epoch_loss_list.append(loss_gen.item())
        disc_epoch_loss_list.append(loss_disc.item())
    gen_loss_by_epochs.append(sum(gen_epoch_loss_list)/len(gen_epoch_loss_list))
    disc_loss_by_epochs.append(sum(disc_epoch_loss_list)/len(disc_epoch_loss_list))
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}. Gen_epoch_loss: {gen_loss_by_epochs[-1]}, Disc_epoch_loss: {disc_loss_by_epochs[-1]}')
save_gan_checkpoint(gen, disc, opt_gen, opt_disc, NUM_EPOCHS, gen_loss_by_epochs, disc_loss_by_epochs, exp_desc)