
shirleyzhu233/realNVP


realNVP

A PyTorch implementation of the training procedure from the paper Density Estimation Using Real NVP. The original TensorFlow implementation can be found at https://github.com/tensorflow/models/tree/master/research/real_nvp.

Implementation Details

This implementation supports training on four datasets: CIFAR-10, CelebA, ImageNet 32x32 and ImageNet 64x64. For each dataset, only the training split is used to learn the distribution, and labels are not used. Raw data is dequantized, randomly flipped horizontally and logit-transformed (see the paper for details). The network architecture faithfully reproduces the one in the paper, and the default hyperparameters are the ones suggested in the paper. Optimization uses Adam with its default parameters. Model performance, measured in bits/dim, matches the results reported in the paper.
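The preprocessing and the bits/dim metric can be sketched in NumPy as follows. This is a minimal illustration, not the repository's code: the squeeze constant `ALPHA = 0.05`, the `(N, C, H, W)` layout and the helper names `preprocess` and `bits_per_dim` are assumptions. The point it shows is that the log-determinant of the preprocessing must be added to the model's log-likelihood, so that bits/dim is reported for the pixel data rather than for the logit-space values.

```python
import numpy as np

ALPHA = 0.05  # assumed boundary constant for the logit squeeze

def preprocess(x_uint8, rng, alpha=ALPHA):
    """Dequantize, randomly flip and logit-transform a batch of 8-bit images.

    x_uint8: integer array of shape (N, C, H, W) with values in [0, 255].
    Returns (y, log_det): y has the same shape as the input, and log_det[i]
    is the sum over all dimensions of log|dy/dx~| for image i, where
    x~ = x + u is the dequantized image.
    """
    x = x_uint8.astype(np.float64)
    # Dequantization: uniform noise turns each integer pixel into a value in [x, x + 1).
    x = x + rng.uniform(0.0, 1.0, size=x.shape)
    # Random horizontal flip of each image with probability 0.5 (last axis is width).
    flip = rng.uniform(size=x.shape[0]) < 0.5
    x[flip] = x[flip, ..., ::-1]
    # Scale to [0, 1), then squeeze into [alpha, 1 - alpha) so the logit stays finite.
    p = alpha + (1.0 - 2.0 * alpha) * (x / 256.0)
    y = np.log(p) - np.log1p(-p)  # logit(p)
    # Per-dimension log|dy/dx~| = log((1 - 2 * alpha) / 256) - log(p) - log(1 - p).
    log_det_dims = np.log((1.0 - 2.0 * alpha) / 256.0) - np.log(p) - np.log1p(-p)
    log_det = log_det_dims.reshape(x.shape[0], -1).sum(axis=1)
    return y, log_det

def bits_per_dim(neg_log_py, log_det, num_dims):
    """Change of variables: -log p(x~) = -log p_Y(y) - log_det, converted from nats to bits per dimension."""
    return (neg_log_py - log_det) / (num_dims * np.log(2.0))

# Usage on a random batch; np.random.RandomState also exists in numpy 1.15.
rng = np.random.RandomState(0)
batch = rng.randint(0, 256, size=(4, 3, 32, 32))
y, log_det = preprocess(batch, rng)
```

As a sanity check on the units, a model that is uniform over the 256 pixel values gives exactly 8 bits/dim under `bits_per_dim`.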

Samples

The samples are generated from models trained with default parameters. Each iteration corresponds to a minibatch of 64 images.
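Turning a model sample back into an image means inverting the logit preprocessing. In a flow, a sample is drawn as z ~ N(0, I) and passed through the inverse of the network; the sketch below covers only the final step, mapping logit-space values back to 8-bit pixels. The helper `postprocess` and the constant `ALPHA = 0.05` are hypothetical and assumed to match the forward transform used during training.

```python
import numpy as np

ALPHA = 0.05  # assumed; must match the constant used in the forward preprocessing

def postprocess(y, alpha=ALPHA):
    """Map logit-space values back to 8-bit pixels (inverse of the logit preprocessing)."""
    p = 1.0 / (1.0 + np.exp(-y))            # sigmoid undoes the logit
    x = (p - alpha) / (1.0 - 2.0 * alpha)    # undo the [alpha, 1 - alpha) squeeze
    pixels = np.floor(x * 256.0)             # back to [0, 256); floor drops the dequantization noise
    return np.clip(pixels, 0, 255).astype(np.uint8)

# Hypothetical usage, where `flow_inverse` stands for the inverse pass of a trained model:
# z = np.random.randn(64, 3, 32, 32)
# images = postprocess(flow_inverse(z))
```

Clipping guards against samples that land slightly outside the squeezed range, which the forward transform never produces but a model sample can.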

CIFAR-10

1000 iterations

80000 iterations
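Since each iteration processes 64 images, the iteration counts translate directly into passes over the training data. A quick check for CIFAR-10, whose training split has 50,000 images (the helper name `epochs` is illustrative):

```python
BATCH_SIZE = 64
CIFAR10_TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training split

def epochs(iterations, batch_size=BATCH_SIZE, train_size=CIFAR10_TRAIN_IMAGES):
    """Number of passes over the training split after `iterations` minibatches."""
    return iterations * batch_size / train_size

print(epochs(1000))   # 1.28
print(epochs(80000))  # 102.4
```

So the 80000-iteration CIFAR-10 samples come from a model that has seen each training image roughly 102 times.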

CelebA

1000 iterations

60000 iterations

ImageNet 32x32

1000 iterations

80000 iterations

ImageNet 64x64

1000 iterations

60000 iterations

Training

The code runs on a single GPU and has been tested with:

  • Python 3.7.2
  • torch 1.0.0
  • numpy 1.15.4

To train on each of the four datasets:

python train.py --dataset=cifar10 --batch_size=64 --base_dim=64 --res_blocks=8 --max_iter=80000
python train.py --dataset=celeba --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000
python train.py --dataset=imnet32 --batch_size=64 --base_dim=32 --res_blocks=4 --max_iter=80000
python train.py --dataset=imnet64 --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000
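The commands keep `--batch_size=64` fixed and vary `--dataset`, `--base_dim`, `--res_blocks` and `--max_iter`. The hypothetical `argparse` parser below mirrors those flags to show how each command line is interpreted. The flag names and dataset choices come from the commands above; the integer types, `required=True` settings and absence of defaults are assumptions, and `train.py` sets its own defaults.

```python
import argparse
import shlex

def build_parser():
    # Hypothetical parser: flag names copied from the training commands; types assumed.
    parser = argparse.ArgumentParser(description="Train a Real NVP model (sketch).")
    parser.add_argument("--dataset", required=True,
                        choices=["cifar10", "celeba", "imnet32", "imnet64"])
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--base_dim", type=int, required=True)
    parser.add_argument("--res_blocks", type=int, required=True)
    parser.add_argument("--max_iter", type=int, required=True)
    return parser

COMMANDS = [
    "--dataset=cifar10 --batch_size=64 --base_dim=64 --res_blocks=8 --max_iter=80000",
    "--dataset=celeba --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000",
    "--dataset=imnet32 --batch_size=64 --base_dim=32 --res_blocks=4 --max_iter=80000",
    "--dataset=imnet64 --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000",
]

parser = build_parser()
for cmd in COMMANDS:
    args = parser.parse_args(shlex.split(cmd))
    print(args.dataset, args.base_dim, args.res_blocks, args.max_iter)
# cifar10 64 8 80000
# celeba 32 2 60000
# imnet32 32 4 80000
# imnet64 32 2 60000
```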
