Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
abarankab committed May 23, 2022
1 parent 31ef9e1 commit 1d42fc8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ddpm/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def sample(self, batch_size, device, y=None, use_ema=True):
x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

for t in range(self.num_timesteps - 1, -1, -1):
t_batch = torch.ones(batch_size, device=device) * t
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, y, use_ema)

if t > 0:
Expand Down
4 changes: 2 additions & 2 deletions ddpm/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def get_diffusion_from_args(args):
else:
betas = generate_linear_schedule(
args.num_timesteps,
1e-4 * 1000 / args.num_timesteps,
0.02 * 1000 / args.num_timesteps,
args.schedule_low * 1000 / args.num_timesteps,
args.schedule_high * 1000 / args.num_timesteps,
)

diffusion = GaussianDiffusion(
Expand Down
21 changes: 13 additions & 8 deletions scripts/train_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def main():
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=-1,
num_workers=2,
))
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)

acc_train_loss = 0

Expand Down Expand Up @@ -94,12 +94,6 @@ def main():
loss = diffusion(x)

test_loss += loss.item()

model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

torch.save(diffusion.state_dict(), model_filename)
torch.save(optimizer.state_dict(), optim_filename)

if args.use_labels:
samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
Expand All @@ -118,6 +112,13 @@ def main():
})

acc_train_loss = 0

if iteration % args.checkpoint_rate:
model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

torch.save(diffusion.state_dict(), model_filename)
torch.save(optimizer.state_dict(), optim_filename)

if args.log_to_wandb:
run.finish()
Expand All @@ -137,13 +138,17 @@ def create_argparser():

log_to_wandb=True,
log_rate=1000,
checkpoint_rate=1000,
log_dir="~/ddpm_logs",
project_name=None,
run_name=run_name,

model_checkpoint=None,
optim_checkpoint=None,

schedule_low=1e-4,
schedule_high=0.02,

device=device,
)
defaults.update(script_utils.diffusion_defaults())
Expand Down

0 comments on commit 1d42fc8

Please sign in to comment.