See `requirements.txt` for package versions. A Jupyter notebook, `ssgan_notebook.ipynb`, is also provided with Bokeh plots of live-updated training metrics.
Note: The implementation and hyperparameters may differ slightly from those described in the papers.
Example training commands for MNIST and MREO:

```bash
python ssgan_exp.py --dataset mnist --epochs 10 --perc_labeled 0.00167 --lr 0.003 --noise_dist uniform --use_weight_norm
python ssgan_exp.py --dataset mreo --epochs 100 --perc_labeled 0.08 --lr 0.0006 --noise_dist normal --no_eq_union
```
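The `--perc_labeled` flag appears to set the fraction of the training set whose labels are kept. For MNIST's 60,000 training images, 0.00167 × 60,000 ≈ 100, which matches the 100 labeled samples in the results table. The sketch below illustrates that arithmetic with a hypothetical `split_labeled` helper; it is not the code in `ssgan_exp.py`. The rounding rule, the random (not class-balanced) selection, and the disjoint labeled/unlabeled split are all assumptions.

```python
import random

MNIST_TRAIN_SIZE = 60_000  # images in the standard MNIST training split


def split_labeled(num_examples, perc_labeled, seed=0):
    """Hypothetical helper: split example indices into labeled/unlabeled sets.

    Assumptions (not taken from ssgan_exp.py): the labeled count is rounded
    to the nearest integer, examples are chosen uniformly at random rather
    than balanced per class, and the two sets are disjoint.
    """
    n_labeled = round(perc_labeled * num_examples)
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    return indices[:n_labeled], indices[n_labeled:]


labeled, unlabeled = split_labeled(MNIST_TRAIN_SIZE, 0.00167)
print(len(labeled), len(unlabeled))  # 100 59900
```

Whether the labeled examples are also reused as unlabeled data during GAN training is an implementation choice that this README does not specify.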
| Dataset | Labeled training data | Accuracy | Reference | Checkpoint |
|---|---|---|---|---|
| MNIST | 100 samples | 0.9509 | [1], see Table 1 | model |
| MREO | 8% | 0.8658 | [2], see Table 1 | model |
MNIST data is loaded from `torchvision.datasets`.
MREO data can be downloaded from here; we use the compact version. Extract it with:

```bash
tar -xvf data_processed_compact.tar.gz mreo_data/
```
[1]: Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen, "Improved Techniques for Training GANs", 2016.
[2]: Z. Erickson, S. Chernova, and C. C. Kemp, "Semi-Supervised Haptic Material Recognition for Robots using Generative Adversarial Networks", 2017.
- (Dec. 28, 2019) Update to PyTorch 1.3. Add results for MNIST and MREO. Add weight normalization, easier setting of hyperparameters, and data loading improvements.
- (Dec. 26, 2019) Fix bug in the labeled loss function; it now properly indexes prediction probabilities.
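The Dec. 26 fix concerns the labeled (supervised) loss. In [1], this term is the cross-entropy $-\log p_{\text{model}}(y \mid x, y < K+1)$ over the $K$ real classes. It requires picking out, for each example, the log-probability at that example's true label. Below is a minimal pure-Python sketch of that indexing under the formulation in [1]. It is not the repository's code, and the names `log_softmax` and `labeled_loss` are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one example's class logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def labeled_loss(batch_logits: Sequence[Sequence[float]],
                 labels: Sequence[int]) -> float:
    """Mean of -log p(y | x, y < K+1) over a batch of labeled examples.

    batch_logits: one list of K real-class logits per example. Conditioning
    on y < K+1 removes the fake-class logit from this probability, so only
    the K real-class logits are needed here.
    labels: the true class index (0..K-1) of each example.
    """
    total = 0.0
    for logits, y in zip(batch_logits, labels):
        # Index this example's log-probability at its *true* class.
        total -= log_softmax(logits)[y]
    return total / len(labels)


loss = labeled_loss([[2.0, 0.0], [0.0, 2.0]], [0, 1])
print(round(loss, 4))  # 0.1269, i.e. log(1 + e^-2)
```

A typical indexing bug is to select the wrong entry or axis, for example averaging the whole probability row instead of taking the true-class entry. With batched tensors, the same per-example selection is usually done with a gather along the class dimension.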