Reproduce the VAE experiment on MNIST. Both the encoder and the decoder are assumed to be Gaussian.
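As a reminder of what the Gaussian assumption implies, the per-example loss can be written as below. This is a sketch under stated assumptions: the prior is taken to be a unit Gaussian, both covariances are taken to be diagonal, and the symbols $\mu_z, \sigma_z^2$ (encoder outputs) and $\mu_x, \sigma_x^2$ (decoder outputs) are notation introduced here for illustration.

$$
\mathcal{L}(x) = -\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] + \mathrm{KL}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big)
$$

$$
-\log p(x \mid z) = \frac{1}{2} \sum_i \left[ \log(2\pi) + \log \sigma_{x,i}^2 + \frac{(x_i - \mu_{x,i})^2}{\sigma_{x,i}^2} \right]
$$

$$
\mathrm{KL}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_j \left[ \sigma_{z,j}^2 + \mu_{z,j}^2 - 1 - \log \sigma_{z,j}^2 \right]
$$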
Notes:
- Reconstruction loss must match the assumed decoder probability distribution.
- Use `torch.randn()` for multivariate unit Gaussian sampling (much faster than `torch.distributions.MultivariateNormal()`).
- `log(\sigma^2)` from both encoder and decoder must be clipped.
- Gradient norm clipping is not required as long as the norm isn't exploding.
- KL loss is expected to increase in the beginning, but should stabilize and converge to lower values later.
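
The notes above can be illustrated with a minimal sketch. It is not the repository's actual code: the function names, the diagonal-covariance parameterization, and the clipping range `[-10, 10]` are assumptions made here for illustration, and a suitable range may differ in practice.

```python
import math
import torch

# Assumed clipping range for log(sigma^2); the notes do not specify values.
LOGVAR_MIN, LOGVAR_MAX = -10.0, 10.0


def reparameterize(mu, logvar):
    """Draw z = mu + sigma * eps, with eps ~ N(0, I) sampled via torch.randn."""
    logvar = logvar.clamp(LOGVAR_MIN, LOGVAR_MAX)
    eps = torch.randn(mu.shape, dtype=mu.dtype, device=mu.device)
    return mu + torch.exp(0.5 * logvar) * eps


def gaussian_nll(x, mu, logvar):
    """Reconstruction loss matching a Gaussian decoder:
    -log N(x; mu, diag(exp(logvar))), summed over the last dimension."""
    logvar = logvar.clamp(LOGVAR_MIN, LOGVAR_MAX)
    per_dim = math.log(2 * math.pi) + logvar + (x - mu) ** 2 / logvar.exp()
    return 0.5 * per_dim.sum(dim=-1)


def kl_to_unit_gaussian(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over the last dimension."""
    logvar = logvar.clamp(LOGVAR_MIN, LOGVAR_MAX)
    return 0.5 * (logvar.exp() + mu ** 2 - 1.0 - logvar).sum(dim=-1)
```

A per-batch loss would then be something like `(gaussian_nll(x, x_mu, x_logvar) + kl_to_unit_gaussian(z_mu, z_logvar)).mean()`, where `x_mu`, `x_logvar`, `z_mu`, and `z_logvar` are hypothetical names for the decoder and encoder outputs. Clamping inside each function keeps `exp(logvar)` finite even when the network emits extreme values.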
Reference: