vdm

Variational Diffusion Models

Implementation and extension of Variational Diffusion Models (Kingma et al., 2021) in JAX and Equinox.

Synopsis

A Variational Diffusion Model (VDM) is, in essence, an infinitely deep hierarchical latent-variable model in which the encoder for every latent variable is available in closed form.
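To make the closed-form encoder concrete, here is a minimal JAX sketch that samples a latent z_t directly from the data x at any depth t in [0, 1]. It is an illustration rather than this repository's code: the names `gamma` and `encode`, the linear noise schedule, its endpoints, and the variance-preserving parameterisation are all assumptions.

```python
import jax
import jax.numpy as jnp

# Assumed endpoints of a linear schedule for gamma(t) = -log SNR(t).
GAMMA_MIN, GAMMA_MAX = -13.3, 5.0


def gamma(t):
    """Negative log signal-to-noise ratio at depth t (linear by assumption)."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def encode(key, x, t):
    """Sample z_t ~ q(z_t | x) = N(alpha_t * x, sigma_t^2 * I) in closed form."""
    g = gamma(t)
    alpha = jnp.sqrt(jax.nn.sigmoid(-g))  # variance-preserving: alpha^2 + sigma^2 = 1
    sigma = jnp.sqrt(jax.nn.sigmoid(g))
    eps = jax.random.normal(key, x.shape)
    return alpha * x + sigma * eps, eps


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
z_early, _ = encode(key, x, 0.0)  # almost no noise: z_0 stays close to x
z_late, _ = encode(key, x, 1.0)   # mostly noise: z_1 is close to a standard normal
```

No network is involved in encoding, so any depth t can be sampled in a single step without simulating the intermediate latents.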

The design closely resembles a Variational Autoencoder (VAE), but where a VAE is fit with a reconstruction loss and a prior KL-divergence, a VDM adds a third term: the consistency (diffusion) loss.
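For reference, the three terms can be written as a bound on the negative log-likelihood. The notation follows Kingma et al. (2021) and is not necessarily the notation used in this code base: z_t = α_t x + σ_t ε is the latent at depth t, γ(t) = -log SNR(t) = log(σ_t² / α_t²) is the noise schedule, and ε̂_θ is the noise-prediction network.

$$
-\log p(x) \;\le\; \underbrace{D_{\mathrm{KL}}\big(q(z_1 \mid x)\,\|\,p(z_1)\big)}_{\text{prior KL}} \;+\; \underbrace{\mathbb{E}_{q(z_0 \mid x)}\big[-\log p(x \mid z_0)\big]}_{\text{reconstruction}} \;+\; \underbrace{\mathcal{L}_\infty(x)}_{\text{diffusion}}
$$

$$
\mathcal{L}_\infty(x) \;=\; \frac{1}{2}\,\mathbb{E}_{\epsilon \sim \mathcal{N}(0, I),\; t \sim \mathcal{U}(0, 1)}\Big[\gamma'(t)\,\big\|\epsilon - \hat{\epsilon}_\theta(z_t; t)\big\|_2^2\Big]
$$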

Here, training uses the continuous-time (infinite-depth) form of the consistency loss, rather than the discretised SDE used in DDPM-style methods.
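As a rough illustration of a continuous-time loss, the sketch below draws one uniform depth t per example and forms a single-sample Monte Carlo estimate of the diffusion term. It is written under assumptions and is not this repository's implementation: `eps_hat` is a hypothetical linear stand-in for the score network, and a fixed linear schedule `gamma` is assumed for brevity.

```python
import jax
import jax.numpy as jnp

GAMMA_MIN, GAMMA_MAX = -13.3, 5.0  # assumed schedule endpoints


def gamma(t):
    """Assumed linear schedule for gamma(t) = -log SNR(t)."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def eps_hat(params, z, g):
    """Hypothetical stand-in for the score network: a linear map conditioned on gamma."""
    return z @ params["w"] + params["b"] * g


def diffusion_loss(params, key, x):
    """Monte Carlo estimate of 0.5 * E[gamma'(t) * ||eps - eps_hat(z_t, t)||^2]."""
    key_t, key_eps = jax.random.split(key)
    t = jax.random.uniform(key_t, (x.shape[0],))   # continuous depth, one per example
    eps = jax.random.normal(key_eps, x.shape)
    g = gamma(t)[:, None]
    dg_dt = jax.vmap(jax.grad(gamma))(t)            # gamma'(t) by automatic differentiation
    z = jnp.sqrt(jax.nn.sigmoid(-g)) * x + jnp.sqrt(jax.nn.sigmoid(g)) * eps
    sq_err = jnp.sum((eps - eps_hat(params, z, g)) ** 2, axis=-1)
    return jnp.mean(0.5 * dg_dt * sq_err)


dim = 8
key_x, key_loss = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jnp.zeros((dim, dim)), "b": jnp.zeros((dim,))}
x = jax.random.normal(key_x, (16, dim))
loss, grads = jax.value_and_grad(diffusion_loss)(params, key_loss, x)
```

Because t is sampled from a continuous range, there is no fixed number of diffusion steps to choose at training time; the schedule only needs to be differentiable in t.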

Features

  • Conditional likelihood modelling,
  • exotic score-network architectures (more to be added),
  • multi-device training and inference.

Usage

pip install variational-diffusion-models 
python main.py

See examples.


CIFAR10

[Figure: CIFAR10]

MNIST

[Figure: MNIST]