Skip to content

Commit

Permalink
Fix to latent sampling in the training loop. May improve learning. Th…
Browse files Browse the repository at this point in the history
…anks Yangkang Zhang!
  • Loading branch information
dorarad authored Jun 14, 2022
1 parent d832541 commit 3a9efa4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_version/training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def fetch_data(dataset, dataset_iter, input_shape, drange_net, device, batches_n
real_c = real_c.to(device).split(batch_gpu)

gen_zs = torch.randn([batches_num * batch_size, *input_shape[1:]], device = device)
gen_zs = [gen_zs.split(batch_gpu) for gen_z in gen_zs.split(batch_size)]
gen_zs = [gen_z.split(batch_gpu) for gen_z in gen_zs.split(batch_size)]

gen_cs = [dataset.get_label(np.random.randint(len(dataset))) for _ in range(batches_num * batch_size)]
gen_cs = torch.from_numpy(np.stack(gen_cs)).pin_memory().to(device)
Expand Down

0 comments on commit 3a9efa4

Please sign in to comment.