diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index df57c6ff38ed..7a72e5c24b3b 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -147,9 +147,9 @@ def transforms(examples): accelerator.wait_for_everyone() - # Generate a sample image for visual inspection + # Generate sample images for visual inspection if accelerator.is_main_process: - with torch.no_grad(): + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: pipeline = DDPMPipeline( unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), scheduler=noise_scheduler, @@ -159,9 +159,11 @@ def transforms(examples): # run pipeline in inference (sample random noise and denoise) images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"] - # denormalize the images and save to tensorboard - images_processed = (images * 255).round().astype("uint8") - accelerator.trackers[0].writer.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + accelerator.trackers[0].writer.add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model @@ -184,7 +186,8 @@ def transforms(examples): parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--save_model_epochs", type=int, default=5) + parser.add_argument("--save_images_epochs", type=int, default=10) + parser.add_argument("--save_model_epochs", type=int, default=10) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--lr_scheduler", type=str, default="cosine")