From 1d42fc8690a5e1ce41c44c426adadb23f21577ee Mon Sep 17 00:00:00 2001 From: Anthony Baryshnikov Date: Tue, 24 May 2022 02:15:42 +0300 Subject: [PATCH] fixes --- ddpm/diffusion.py | 2 +- ddpm/script_utils.py | 4 ++-- scripts/train_cifar.py | 21 +++++++++++++-------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/ddpm/diffusion.py b/ddpm/diffusion.py index e5d5fd9..6575b4d 100644 --- a/ddpm/diffusion.py +++ b/ddpm/diffusion.py @@ -105,7 +105,7 @@ def sample(self, batch_size, device, y=None, use_ema=True): x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device) for t in range(self.num_timesteps - 1, -1, -1): - t_batch = torch.ones(batch_size, device=device) * t + t_batch = torch.tensor([t], device=device).repeat(batch_size) x = self.remove_noise(x, t_batch, y, use_ema) if t > 0: diff --git a/ddpm/script_utils.py b/ddpm/script_utils.py index cbefaeb..92b7f1a 100644 --- a/ddpm/script_utils.py +++ b/ddpm/script_utils.py @@ -106,8 +106,8 @@ def get_diffusion_from_args(args): else: betas = generate_linear_schedule( args.num_timesteps, - 1e-4 * 1000 / args.num_timesteps, - 0.02 * 1000 / args.num_timesteps, + args.schedule_low * 1000 / args.num_timesteps, + args.schedule_high * 1000 / args.num_timesteps, ) diffusion = GaussianDiffusion( diff --git a/scripts/train_cifar.py b/scripts/train_cifar.py index 2050907..1692ec4 100644 --- a/scripts/train_cifar.py +++ b/scripts/train_cifar.py @@ -54,9 +54,9 @@ def main(): batch_size=batch_size, shuffle=True, drop_last=True, - num_workers=-1, + num_workers=2, )) - test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=4) + test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2) acc_train_loss = 0 @@ -94,12 +94,6 @@ def main(): loss = diffusion(x) test_loss += loss.item() - - model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth" - optim_filename = 
f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth" - - torch.save(diffusion.state_dict(), model_filename) - torch.save(optimizer.state_dict(), optim_filename) if args.use_labels: samples = diffusion.sample(10, device, y=torch.arange(10, device=device)) @@ -118,6 +112,13 @@ def main(): }) acc_train_loss = 0 + + if iteration % args.checkpoint_rate == 0: + model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth" + optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth" + + torch.save(diffusion.state_dict(), model_filename) + torch.save(optimizer.state_dict(), optim_filename) if args.log_to_wandb: run.finish() @@ -137,6 +138,7 @@ def create_argparser(): log_to_wandb=True, log_rate=1000, + checkpoint_rate=1000, log_dir="~/ddpm_logs", project_name=None, run_name=run_name, @@ -144,6 +146,9 @@ def create_argparser(): model_checkpoint=None, optim_checkpoint=None, + schedule_low=1e-4, + schedule_high=0.02, + device=device, ) defaults.update(script_utils.diffusion_defaults())