Code for reproducing some key results of our NeurIPS 2019 submission, *A New Distribution on the Simplex with Auto-Encoding Applications*.
- Requirements
  - Python (version 3.6.7 or higher)
  - Python requirements are captured in `requirements.txt`
  - For GPU-accelerated TensorFlow:
    - Run `pip install -r requirements.txt`
  - For CPU-only TensorFlow:
    - Change `tensorflow-gpu` to `tensorflow` in `requirements.txt`
    - Run `pip install -r requirements.txt`
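If you prefer to script the GPU-to-CPU swap, a minimal sketch using GNU sed follows (the pinned version in the sample line is purely illustrative, not taken from `requirements.txt`):

```shell
# Illustrative only: the substitution applied to a sample requirements line.
# In practice, edit the file in place with:
#   sed -i 's/^tensorflow-gpu/tensorflow/' requirements.txt
# (on macOS/BSD sed, use `sed -i ''` instead of `sed -i`)
echo 'tensorflow-gpu==1.13.1' | sed 's/^tensorflow-gpu/tensorflow/'
# → tensorflow==1.13.1
```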
We utilize the TensorFlow Datasets API. Running any of our code that requires data will automatically download the requisite data.
- Experiment Scripts
  - `experiments_ss_run.py` runs our semi-supervised learning experiments.
  - `experiments_ss_analyze.py` analyzes the results generated by `experiments_ss_run.py`.
- Multivariate Kumaraswamy Code
  - `mv_kumaraswamy_sampler.py` contains a TensorFlow implementation of the Multivariate Kumaraswamy.
  - `mv_kumaraswamy_theory.py` contains a symbolic implementation of the stick-breaking process that supports Beta and Kumaraswamy distributions.
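For intuition, the stick-breaking construction underlying the Multivariate Kumaraswamy can be sketched in plain Python using the Kumaraswamy's closed-form inverse CDF. This is an illustrative sketch only, not the repository's TensorFlow implementation, and the function names below are ours:

```python
import random

def kumaraswamy_sample(a, b, u=None):
    """Draw from Kumaraswamy(a, b) via the closed-form inverse CDF:
    x = (1 - (1 - u)^(1/b))^(1/a) for u ~ Uniform(0, 1)."""
    if u is None:
        u = random.random()
    return (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)

def stick_breaking_simplex(params):
    """Map K-1 Kumaraswamy draws (one (a, b) pair each) to a point
    on the K-dimensional simplex via stick breaking."""
    remaining = 1.0  # length of the stick still unbroken
    pis = []
    for a, b in params:
        v = kumaraswamy_sample(a, b)  # fraction of the remaining stick
        pis.append(remaining * v)
        remaining *= (1.0 - v)
    pis.append(remaining)  # final segment absorbs what is left
    return pis
```

Because each break removes a fraction of the remaining stick and the final segment absorbs the remainder, the output is nonnegative and sums to one by construction.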
- Model Files
  - `models_vae.py` contains our VAE models:
    - `VariationalAutoEncoder`: base class that contains common parameters and functions
    - `AutoEncodingKumaraswamy`: our proposed model (MV-Kumaraswamy), which works for dim(z) >= 0
    - `AutoEncodingSoftmax`: our softmax baseline model, which works for dim(z) >= 0
    - `KingmaM2`: our implementation of Kingma's M2 model, which works for dim(z) > 0
  - `model_lib.py` contains functions used to construct the inference and recognition network operations.
  - `model_utils.py` contains data loading/splitting, training routines, and other support functions.
- Miscellaneous
  - `unit_test.py`
  - `python mv_kumaraswamy_sampler.py` will plot and show Figure 1 (among other non-utilized figures).
  - `python mv_kumaraswamy_theory.py` will plot and show Figures 2-4 (among other non-utilized figures).
  - `ars-reparameterization/dirichlet-multinomial.ipynb` (modified from https://github.com/blei-lab/ars-reparameterization) was used to generate Figure 5.
  - `python model_utils.py` will plot and show something similar to Figure 6.
To fully rerun our experiments, we recommend defining a new, unused directory prefix; for example, `your_results_` suffices.
- Table 1:
  - First, run the following and wait for completion:

    ```
    python experiments_ss_run.py --dir_prefix your_results_ --num_runs 10 --data_set mnist --num_labelled 600 --dim_z 0
    python experiments_ss_run.py --dir_prefix your_results_ --num_runs 10 --data_set mnist --num_labelled 600 --dim_z 2
    python experiments_ss_run.py --dir_prefix your_results_ --num_runs 10 --data_set mnist --num_labelled 600 --dim_z 50
    ```

  - Second, run:

    ```
    python experiments_ss_analyze.py --dir_prefix your_results_ --data_set mnist
    ```
- Table 2:
  - First, run the following and wait for completion:

    ```
    python experiments_ss_run.py --dir_prefix your_results_ --num_runs 4 --data_set svhn_cropped --num_labelled 1000 --dim_z 50
    ```

  - Second, run:

    ```
    python experiments_ss_analyze.py --dir_prefix your_results_ --data_set svhn_cropped
    ```
- `results_ss_mnist` and `results_ss_svhn_cropped` contain the results from our original submission. Per reviewer feedback, we augmented this data with some additional baselines; however, these new baselines were not guaranteed to experience the same data folds. To analyze this data, run:

  ```
  python experiments_ss_analyze.py --dir_prefix results_ss_ --data_set mnist
  python experiments_ss_analyze.py --dir_prefix results_ss_ --data_set svhn_cropped
  ```
- To prepare our camera-ready submission, we reran all experiments such that all baselines would experience the same data folds. This "new" data is contained in `new_results_ss_mnist` and `new_results_ss_svhn_cropped`. These new sets can be analyzed with:

  ```
  python experiments_ss_analyze.py --dir_prefix new_results_ss_ --data_set mnist
  python experiments_ss_analyze.py --dir_prefix new_results_ss_ --data_set svhn_cropped
  ```