MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/JAX that run on XLA devices, including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching point for ambitious diffusion projects in both research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify it to meet your needs.
The goal of this project is to provide reference implementations of latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices, including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs and welcome code contributions to grow the project.
MaxDiffusion supports:
- Stable Diffusion 2 base (training and inference)
- Stable Diffusion 2.1 (training and inference)
- Stable Diffusion XL (inference)
WARNING: The training code is purely experimental and is under development.
We recommend starting with a single TPU host and then moving to multihost.
Minimum requirements: Ubuntu 22.04, Python 3.10, and TensorFlow >= 2.12.0.
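If you want to confirm the Python and TensorFlow minimums before installing anything else, a small standard-library script like the one below can help. It is a sketch, not part of MaxDiffusion: it assumes TensorFlow is installed under the distribution name "tensorflow" (a "tensorflow-cpu" install would be reported as missing), and it does not check the Ubuntu version.

```python
import sys
from importlib import metadata

# Minimum versions stated in the README.
MIN_PYTHON = (3, 10)
MIN_TENSORFLOW = (2, 12, 0)


def parse_version(text):
    """Convert '2.15.1' to (2, 15, 1).

    Simplification: parsing stops at the first non-numeric component, so a
    pre-release such as '2.13.0rc1' is treated as (2, 13).
    """
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_requirements():
    """Return a list of problems found; an empty list means the checks passed."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append(f"Python {sys.version.split()[0]} is older than 3.10")
    try:
        tf_version = metadata.version("tensorflow")
    except metadata.PackageNotFoundError:
        problems.append("TensorFlow is not installed (looked for 'tensorflow')")
    else:
        if parse_version(tf_version) < MIN_TENSORFLOW:
            problems.append(f"TensorFlow {tf_version} is older than 2.12.0")
    return problems


if __name__ == "__main__":
    found = check_requirements()
    print("\n".join(found) if found else "Python and TensorFlow meet the minimums.")
```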
Local development is a convenient way to run MaxDiffusion on a single host.
- Create and SSH to a single-host TPU (v4-8).
- Clone MaxDiffusion in your TPU VM.
- Within the root directory of the MaxDiffusion git repo, install dependencies by running:
pip3 install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
- After installation completes, run the training script:
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" base_output_directory="gs://your-bucket/" train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/tf_records
- To generate images, run one of the following commands:
  - Stable Diffusion 2.1
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
  - Stable Diffusion XL Lightning
Multi-host inference is supported with sharding annotations:
python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name="my_run"
  - Stable Diffusion XL
Multi-host inference is supported with sharding annotations:
python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
Single-host pmap version:
python -m src.maxdiffusion.generate_sdxl_replicated
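The SDXL commands above come in two parallelism styles: sharding annotations (multi-host) and pmap (single host). The toy JAX example below contrasts the two on whatever devices are available, including a single CPU. It is an illustration, not MaxDiffusion's code: denoise_step is a hypothetical stand-in for a UNet call, the mesh axis name "data" is an arbitrary choice, and it assumes a single JAX process so that local and global devices coincide.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def denoise_step(latents):
    # Hypothetical stand-in for one UNet call; the real models are far larger.
    return latents * 0.5 + 1.0


num_devices = jax.local_device_count()
per_device = 2
latents = jnp.ones((num_devices * per_device, 4, 64, 64), dtype=jnp.float32)

# pmap style: reshape so the leading axis equals the number of local devices;
# each device runs its own copy of the function on its slice.
pmap_out = jax.pmap(denoise_step)(
    latents.reshape(num_devices, per_device, *latents.shape[1:])
).reshape(latents.shape)

# Sharding style: keep the global array shape, annotate how the batch axis is
# split over a device mesh, and let jit/XLA partition the computation.
mesh = Mesh(np.array(jax.local_devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
sharded_latents = jax.device_put(latents, batch_sharding)
jit_out = jax.jit(denoise_step)(sharded_latents)

print(jit_out.sharding)
```

With pmap, the caller restructures the batch so its leading axis matches the device count. With sharding annotations, the array keeps its global shape and only the mesh and PartitionSpec describe the layout, so the same jit-compiled function can be pointed at a larger mesh without reshaping the inputs.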
Multihost training for Stable Diffusion 2 base can be run using the following command:
TPU_NAME=<your-tpu-name>
ZONE=<your-zone>
PROJECT_ID=<your-project-id>
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run base_output_directory=gs://your-bucket/"
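The training commands pass a YAML config file followed by key=value overrides such as run_name and base_output_directory. The sketch below shows one plausible way such overrides can be merged into an already-loaded config. It is hypothetical: MaxDiffusion's own config loader may parse, type, and validate values differently, YAML loading (for example with PyYAML's yaml.safe_load) is omitted, and base_config is made-up stand-in data.

```python
def parse_value(text):
    """Hypothetical type inference: bool, then int, then float, else string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of config with each key=value override applied."""
    merged = dict(config)
    for item in overrides:
        # partition splits on the first '=', so values may contain '='.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        if key not in merged:
            # Design choice for this sketch: unknown keys fail loudly.
            raise KeyError(f"unknown config key {key!r}")
        merged[key] = parse_value(value)
    return merged


# Made-up stand-in for values loaded from base_2_base.yml.
base_config = {"run_name": "", "base_output_directory": "", "train_data_dir": ""}

# The shell has already removed quotes, so run_name="my_run" arrives as run_name=my_run.
cli_args = ["run_name=my_run", "base_output_directory=gs://your-bucket/"]
config = apply_overrides(base_config, cli_args)
print(config)
```

Rejecting unknown keys is a choice made for this sketch so that a misspelled override raises an error instead of being silently ignored.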
MaxDiffusion started as a fork of Diffusers, a Hugging Face diffusion library written in Python, PyTorch, and JAX. MaxDiffusion is compatible with Hugging Face JAX models. It is more complex than Diffusers and was designed to run distributed across TPU Pods.
Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, you can find a full suite of tests in tests and src/maxdiffusion/tests.
To run unit tests and lint, simply run:
python -m pytest
ruff check --fix .
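If you prefer a single entry point for both checks, a small Python wrapper can run them in order and stop at the first failure. This is a sketch, not a script shipped with MaxDiffusion; it assumes pytest and ruff are installed in the active environment. Note that ruff check --fix rewrites files in place.

```python
import subprocess
import sys

# The README's two commands. pytest runs through the current interpreter so
# the active environment is used; ruff is invoked as a standalone executable.
CHECKS = [
    [sys.executable, "-m", "pytest"],
    ["ruff", "check", "--fix", "."],
]


def run_checks(commands, cwd="."):
    """Run each command in order; return the first nonzero exit code, else 0."""
    for cmd in commands:
        print("$", " ".join(cmd), flush=True)
        try:
            result = subprocess.run(cmd, cwd=cwd)
        except FileNotFoundError:
            # The executable (e.g. ruff) is not on PATH.
            print(f"command not found: {cmd[0]}", file=sys.stderr)
            return 127
        if result.returncode != 0:
            return result.returncode
    return 0
```

From the repository root, calling sys.exit(run_checks(CHECKS)) runs the tests and then the linter, exiting with the first failing command's status.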
The full suite of end-to-end tests is in tests and src/maxdiffusion/tests. We run them on a nightly cadence.