homerjed/transformer_flow
Transformer flow

Implementation of Apple ML's Transformer Flow (or TARFlow) from "Normalizing Flows are Capable Generative Models" in JAX and Equinox.

Features:

  • layers built with jax.vmap and executed with jax.lax.scan, for fast compilation and execution,
  • multi-device training, inference and sampling,
  • score-based denoising step (see paper),
  • conditioning via class embedding (for discrete class labels) or adaptive layer-normalisation (for continuous variables, as in DiT),
  • arrays typed to the teeth with jaxtyping and beartype, for dependable execution.
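The vmap/scan construction above can be sketched as follows. This is a minimal, hypothetical illustration (the layer parameters and function names are illustrative, not the repository's API): identical layers are initialised in a stacked pytree with jax.vmap, then the forward pass threads the input through them with jax.lax.scan, so compile time stays constant in the number of layers.

```python
import jax
import jax.numpy as jnp

def init_layer(key, dim):
    # Parameters of one affine layer (names are illustrative).
    k1, k2 = jax.random.split(key)
    return {
        "scale": 0.01 * jax.random.normal(k1, (dim,)),
        "shift": 0.01 * jax.random.normal(k2, (dim,)),
    }

def forward_layer(x, params):
    # One invertible element-wise affine map; returns output and log-det.
    y = x * jnp.exp(params["scale"]) + params["shift"]
    logdet = jnp.sum(params["scale"])
    return y, logdet

def forward_flow(x, stacked_params):
    # scan threads x through all layers, collecting per-layer log-dets.
    def step(carry, params):
        y, logdet = forward_layer(carry, params)
        return y, logdet
    y, logdets = jax.lax.scan(step, x, stacked_params)
    return y, jnp.sum(logdets)

key = jax.random.PRNGKey(0)
dim, n_layers = 4, 8
keys = jax.random.split(key, n_layers)
# vmap over keys builds a stacked pytree with a leading layer axis.
stacked = jax.vmap(init_layer, in_axes=(0, None))(keys, dim)

x = jnp.ones((dim,))
y, logdet = forward_flow(x, stacked)
print(y.shape, float(logdet))
```

The same stacked-pytree idea carries over to Equinox modules (e.g. via eqx.filter_vmap), where the array leaves get a leading layer axis and scan iterates over it.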

To implement:

  • Guidance
  • Denoising
  • Mixed precision
  • EMA
  • AdaLayerNorm
  • Class embedding
  • Hyperparameter/model saving
  • Uniform noise for dequantisation
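The last to-do item, uniform dequantisation, can be sketched in a few lines. This is a generic illustration, not the repository's implementation: integer pixel values in [0, 255] get uniform noise in [0, 1) added before rescaling, so the flow fits a continuous density instead of discrete mass points.

```python
import jax
import jax.numpy as jnp

def dequantise(key, images_uint8):
    # Map uint8 pixels to continuous values in [0, 1) by adding
    # uniform noise before rescaling (standard dequantisation trick).
    x = images_uint8.astype(jnp.float32)
    noise = jax.random.uniform(key, x.shape)  # u ~ U[0, 1)
    return (x + noise) / 256.0

key = jax.random.PRNGKey(0)
imgs = jnp.zeros((2, 8, 8, 1), dtype=jnp.uint8)  # dummy batch
x = dequantise(key, imgs)
print(x.shape)
```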
```bibtex
@misc{zhai2024normalizingflowscapablegenerative,
      title={Normalizing Flows are Capable Generative Models},
      author={Shuangfei Zhai and Ruixiang Zhang and Preetum Nakkiran and David Berthelot and Jiatao Gu and Huangjie Zheng and Tianrong Chen and Miguel Angel Bautista and Navdeep Jaitly and Josh Susskind},
      year={2024},
      eprint={2412.06329},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2412.06329},
}
```
