TensorFlow code for the paper "Self-supervised GAN: Analysis and Improvement with Multi-class Minimax Game".
- Python (2.7, 3.5, 3.6), NumPy, TensorFlow, SciPy, scikit-learn
- Recent NVIDIA GPUs
- Supported datasets: MNIST, Stacked MNIST (or MNIST 1K), CelebA, CIFAR-10/100, STL-10 and ImageNet 32x32.
- When the code runs for the first time, the dataset is automatically downloaded to the path given by --data_source=<your path>.
ss_task: 0 (no SS task), 1 (SS task), 2 (our MS task).
Run python msdistgan_[*].py --help for more information.
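All scripts take their options in the same --name=value form, so the commands in this README can also be assembled from Python. The sketch below is hypothetical: build_command and run are not part of this repository, and they only format flags documented here (such as --ss_task, --is_train and --k) before invoking the chosen script. It assumes Python 3.7+, where keyword arguments keep their order.

```python
import subprocess
import sys


def build_command(script, **flags):
    """Hypothetical helper: return an argv list such as
    ['python', 'msdistgan_mnist1k.py', '--k=2', '--ss_task=2', '--is_train=0'].

    Flags are emitted in the order given, each as --name=value.
    """
    argv = ["python", script]
    for name, value in flags.items():
        argv.append("--{}={}".format(name, value))
    return argv


def run(script, **flags):
    """Hypothetical helper: run one job and raise CalledProcessError on failure."""
    # Use the current interpreter rather than whichever "python" is on PATH.
    argv = [sys.executable] + build_command(script, **flags)[1:]
    subprocess.run(argv, check=True)


if __name__ == "__main__":
    # ss_task: 0 = no SS task, 1 = SS task, 2 = MS task
    print(" ".join(build_command("msdistgan_mnist1k.py", k=2, ss_task=2, is_train=0)))
```

The printed line is the Stacked MNIST testing command listed below; calling run(...) with the same arguments would execute it.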
//Training with MS task on MNIST
python msdistgan_mnist.py --ss_task=2
There is no evaluation method for MNIST.
//Training with MS task on Stacked MNIST
python msdistgan_mnist1k.py --k=2 --ss_task=2 --is_train=1
//Testing with MS task on Stacked MNIST
python msdistgan_mnist1k.py --k=2 --ss_task=2 --is_train=0
k: 4 (K/4 architecture), 2 (K/2 architecture), 1 (full-size architecture). Refer to the network architectures of Unrolled GAN [1].
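Stacked MNIST is commonly evaluated, as in Unrolled GAN [1], by labelling each of the three stacked digits of a generated sample with a pre-trained MNIST classifier. Each sample then falls into one of 1000 modes, and one reports how many modes are covered and the KL divergence between the generated mode distribution and the uniform one. Assuming the test run reports a metric of this kind (this README does not say), the computation can be sketched as follows; count_modes is a hypothetical helper and the classifier is not shown.

```python
import numpy as np


def count_modes(channel_labels, num_classes=10):
    """Hypothetical helper: mode coverage statistics for Stacked MNIST samples.

    channel_labels: int array of shape (N, 3); entry [i, c] is the digit that a
    pre-trained MNIST classifier assigns to channel c of generated sample i.
    Returns (number of distinct modes, KL(p_model || uniform over all modes)).
    """
    labels = np.asarray(channel_labels, dtype=np.int64)
    n_channels = labels.shape[1]
    num_modes = num_classes ** n_channels  # 1000 for three stacked digits
    # Encode each (d0, d1, d2) triple as a single integer in [0, num_modes).
    weights = num_classes ** np.arange(n_channels - 1, -1, -1)
    mode_ids = labels @ weights
    counts = np.bincount(mode_ids, minlength=num_modes)
    p_model = counts / counts.sum()
    covered = int(np.count_nonzero(counts))
    # KL(p_model || uniform); empty modes contribute nothing (0 * log 0 := 0).
    nz = p_model > 0
    kl = float(np.sum(p_model[nz] * np.log(p_model[nz] * num_modes)))
    return covered, kl
```

For example, count_modes([[1, 2, 3], [1, 2, 3], [0, 0, 0]]) covers 2 modes, while a sample set that hits every one of the 1000 modes equally often gives a KL of 0.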
//Training with MS task on CelebA
python msdistgan_celeba.py --ss_task=2 --is_train=1
//Testing with MS task on CelebA
python msdistgan_celeba.py --ss_task=2 --is_train=0
//Training with MS task on CIFAR-10 with ResNet and hinge loss
python msdistgan_cifar.py --db_name=cifar10 --nnet_type=resnet --loss_type=hinge --ss_task=2 --data_source=./data/cifar10/ --is_train=1
//Training with MS task on CIFAR-100 with ResNet and hinge loss
python msdistgan_cifar.py --db_name=cifar100 --nnet_type=resnet --loss_type=hinge --ss_task=2 --data_source=./data/cifar100/ --is_train=1
//Computing FID (10K real vs. 10K generated images) for the pre-trained CIFAR-10 model
python msdistgan_cifar.py --db_name=cifar10 --nnet_type=resnet --loss_type=hinge --ss_task=2 --data_source=./data/cifar10/ --nb_test_real=10000 --nb_test_fake=10000 --is_train=0
//Computing FID (10K real vs. 10K generated images) for the pre-trained CIFAR-100 model
python msdistgan_cifar.py --db_name=cifar100 --nnet_type=resnet --loss_type=hinge --ss_task=2 --data_source=./data/cifar100/ --nb_test_real=10000 --nb_test_fake=10000 --is_train=0
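FID is the Fréchet distance between two Gaussians fitted to feature embeddings of real and generated images; in the commands above, --nb_test_real and --nb_test_fake set how many images of each kind are used. The formula is

FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2))

As a sketch of this formula only, assuming the features (typically Inception activations) have already been extracted, it can be computed with NumPy and SciPy as below. frechet_distance is a hypothetical name and is not necessarily how the scripts compute FID.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_real, feats_fake):
    """Hypothetical helper: FID-style distance between two feature sets.

    feats_real, feats_fake: float arrays of shape (N, D), e.g. Inception
    features of real and generated images (feature extraction not shown).
    """
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    # rowvar=False: rows are samples, columns are feature dimensions.
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    # Matrix square root of the covariance product; numerical error can
    # introduce a tiny imaginary component, which is discarded.
    covmean = linalg.sqrtm(sigma_r.dot(sigma_f))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu_r - mu_f
    return float(diff.dot(diff) + np.trace(sigma_r + sigma_f - 2.0 * covmean))
```

Two identical feature sets give a distance of (numerically) 0, and shifting every feature by a constant c in D dimensions adds D * c^2, since the covariances are unchanged.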
//Training with MS task on STL-10 with ResNet and hinge loss
python msdistgan_stl10.py --nnet_type=resnet --loss_type=hinge --ss_task=2 --is_train=1
//Computing FID (10K real vs. 10K generated images) for the pre-trained STL-10 model
python msdistgan_stl10.py --nnet_type=resnet --loss_type=hinge --ss_task=2 --nb_test_real=10000 --nb_test_fake=10000 --is_train=0
//Training with MS task on ImageNet 32x32 with ResNet and hinge loss
python msdistgan_imagenet32.py --nnet_type=resnet --loss_type=hinge --ss_task=2 --is_train=1
//Computing FID (10K real vs. 10K generated images) for the pre-trained ImageNet 32x32 model
python msdistgan_imagenet32.py --nnet_type=resnet --loss_type=hinge --ss_task=2 --nb_test_real=10000 --nb_test_fake=10000 --is_train=0
If you find this work useful in your research, please consider citing:
@InProceedings{tran_2019_neurips_gan,
author = {Tran, Ngoc-Trung and Tran, Viet-Hung and Nguyen, Ngoc-Bao and Yang, Linxiao and Cheung, Ngai-Man},
title = {Self-supervised GAN: Analysis and Improvement with Multi-class Minimax Game},
booktitle = {NeurIPS},
month = {December},
year = {2019}
}
[1] Luke Metz, Ben Poole, David Pfau, Jascha Sohl-Dickstein, "Unrolled Generative Adversarial Networks", ICLR 2017.
We plan to release the MS task for the standard GAN, as well as our TPU code, soon.