Implementation of Apple ML's Transformer Flow (or TARFlow) from *Normalizing Flows are Capable Generative Models* in `jax` and `equinox`.
Features:
- `jax.vmap` & `jax.lax.scan` for layer construction and the forward pass respectively, for fast compilation and execution (see the sketch after this list),
- multi-device training, inference and sampling,
- score-based denoising step (see paper; sketched below),
- conditioning via class embedding (for discrete class labels) or adaptive layer-normalisation (for continuous variables, like in DiT; sketched below),
- array-typed to-the-teeth for dependable execution with `jaxtyping` and `beartype` (sketched below).
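
A minimal sketch of the `jax.vmap`-construction / `jax.lax.scan`-forward-pass pattern, with a hypothetical `Block` standing in for the repo's actual TARFlow transformer block:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Block(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, dim, key):
        self.linear = eqx.nn.Linear(dim, dim, key=key)

    def __call__(self, x):
        return x + jax.nn.gelu(jax.vmap(self.linear)(x))


def make_blocks(n_layers, dim, key):
    # vmap over PRNG keys builds n_layers blocks whose parameters are stacked
    # along a leading axis -- the constructor is traced and compiled once.
    keys = jax.random.split(key, n_layers)
    return eqx.filter_vmap(lambda k: Block(dim, k))(keys)


def forward(blocks, x):
    # lax.scan over the stacked parameters applies the layers sequentially
    # without unrolling them in the jaxpr, which keeps compilation fast.
    params, static = eqx.partition(blocks, eqx.is_array)

    def step(h, layer_params):
        layer = eqx.combine(layer_params, static)
        return layer(h), None

    h, _ = jax.lax.scan(step, x, params)
    return h


blocks = make_blocks(n_layers=8, dim=64, key=jax.random.PRNGKey(0))
y = forward(blocks, jnp.ones((16, 64)))  # (sequence, features)
```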
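
A minimal sketch of the score-based denoising step, assuming a hypothetical `log_prob(model, x)` that returns the flow's scalar log-density for a single sample; the idea is Tweedie's formula with the noise scale `sigma` that was added to the training data:

```python
import jax


def denoise(model, x, sigma, log_prob):
    # Score of the model density at the sample, obtained by differentiating
    # the flow's exact log-likelihood with respect to the input.
    score = jax.grad(lambda x_: log_prob(model, x_))(x)
    return x + sigma**2 * score
```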
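
A minimal sketch of the adaptive layer-normalisation conditioning (DiT-style), where a conditioning vector `c` is mapped to a per-feature scale and shift; for discrete class labels, an `eqx.nn.Embedding` producing `c` plays the same role. Module and argument names here are illustrative, not the repo's actual API:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class AdaLayerNorm(eqx.Module):
    norm: eqx.nn.LayerNorm
    to_scale_shift: eqx.nn.Linear

    def __init__(self, dim, cond_dim, key):
        # No learned affine on the norm itself: scale and shift come from `c`.
        self.norm = eqx.nn.LayerNorm(dim, use_weight=False, use_bias=False)
        self.to_scale_shift = eqx.nn.Linear(cond_dim, 2 * dim, key=key)

    def __call__(self, x, c):
        scale, shift = jnp.split(self.to_scale_shift(c), 2)
        return self.norm(x) * (1 + scale) + shift


ada = AdaLayerNorm(dim=64, cond_dim=32, key=jax.random.PRNGKey(0))
y = ada(jnp.ones(64), jnp.ones(32))
```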
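
A minimal sketch of the `jaxtyping` + `beartype` array typing; the concrete shape annotations in the repo may differ:

```python
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)
def patchify(x: Float[Array, "c h w"], patch: int) -> Float[Array, "n d"]:
    # Shapes and dtypes are checked at call time, so a mismatched input fails
    # immediately instead of deep inside a jit trace.
    c, h, w = x.shape
    x = x.reshape(c, h // patch, patch, w // patch, patch)
    x = jnp.transpose(x, (1, 3, 0, 2, 4))
    return x.reshape(-1, c * patch * patch)
```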
To implement:
- Guidance
- Denoising
- Mixed precision
- EMA
- AdaLayerNorm
- Class embedding
- Hyperparameter/model saving
- Uniform noise for dequantisation
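
A minimal sketch of the uniform dequantisation still on the list above: uniform noise is added to the discrete pixel values before scaling, so the model fits a continuous density rather than a discrete distribution:

```python
import jax
import jax.numpy as jnp


def dequantise(x_uint8, key):
    u = jax.random.uniform(key, x_uint8.shape)         # u ~ U[0, 1)
    return (x_uint8.astype(jnp.float32) + u) / 256.0   # values in [0, 1)
```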
Citation:

```bibtex
@misc{zhai2024normalizingflowscapablegenerative,
  title={Normalizing Flows are Capable Generative Models},
  author={Shuangfei Zhai and Ruixiang Zhang and Preetum Nakkiran and David Berthelot and Jiatao Gu and Huangjie Zheng and Tianrong Chen and Miguel Angel Bautista and Navdeep Jaitly and Josh Susskind},
  year={2024},
  eprint={2412.06329},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2412.06329},
}
```