PixArt-Sigma basic training/inference script #639

Merged Sep 24, 2024 (47 commits)
Commits
3807190
add PixArt-Sigma inference
zhtmike Jul 31, 2024
ab48d09
Update README.md
zhtmike Jul 31, 2024
ee4541e
support kv compress model inference
zhtmike Aug 1, 2024
b5f1f86
revert to master
zhtmike Aug 9, 2024
c850ddf
update README and drop EMA module
zhtmike Aug 9, 2024
a3fd839
add PixArt-Sigma inference
zhtmike Jul 31, 2024
bde1220
Update README.md
zhtmike Jul 31, 2024
9f930fc
support kv compress model inference
zhtmike Aug 1, 2024
2b96eea
add training script
zhtmike Aug 20, 2024
b5dc7e0
speed up sampling
zhtmike Aug 21, 2024
ac45e0c
add recompute
zhtmike Aug 21, 2024
144bad1
Merge branch 'pixart' into pixart_dev
zhtmike Aug 21, 2024
ab6aa13
drop EMA
zhtmike Aug 21, 2024
95b586c
add training with bucketing (pynative only)
zhtmike Aug 21, 2024
61dd135
disable group strategy
zhtmike Aug 21, 2024
e760b48
add multi-scale info
zhtmike Aug 22, 2024
7c7ce9d
fix dynamic shape for network
zhtmike Aug 22, 2024
853921c
support vae dynamic shape
zhtmike Aug 22, 2024
2af21d1
use KBK for normal loss value
zhtmike Aug 22, 2024
e33451c
fix name
zhtmike Aug 22, 2024
3f014de
support visualization during training
zhtmike Aug 23, 2024
d4f932b
fix step print in bucket training
zhtmike Aug 27, 2024
7f494fb
support CAME optimizer
zhtmike Aug 27, 2024
c1c57ef
add seed for visusalization for distributed mode
zhtmike Aug 28, 2024
6d92239
update adamw_re
zhtmike Aug 28, 2024
c511174
refactor for iddpm
zhtmike Aug 28, 2024
058daa9
fix refactor
zhtmike Aug 28, 2024
54bbccf
speed up inference
zhtmike Aug 28, 2024
3d57c13
support DPM++
zhtmike Aug 28, 2024
b7f8d39
update tqdm and docstring
zhtmike Aug 28, 2024
e4322bc
update README and config
zhtmike Aug 29, 2024
0e06200
fix train and speed up infer
zhtmike Aug 29, 2024
0d3ea98
refactor visual, update README and support batch infer
zhtmike Aug 29, 2024
9744145
speed up save
zhtmike Aug 29, 2024
9cfd3a5
Merge branch 'master' into pixart_dev
zhtmike Aug 29, 2024
c4bad45
fix O1, fix text encode, speed up sampling
zhtmike Aug 30, 2024
b2638b5
fix CAME, add auto_lr, fix plot during train, add data filter
zhtmike Sep 3, 2024
3cdddf4
update README, fix fp16 image visualization, add requirements and upd…
zhtmike Sep 3, 2024
1c04572
clean warning, update config and README
zhtmike Sep 3, 2024
4f82041
fix LR scheduler
zhtmike Sep 4, 2024
9526304
reduce the checkpoint size
zhtmike Sep 5, 2024
3421ab4
refactor batch inference
zhtmike Sep 5, 2024
1d1e4bf
update README
zhtmike Sep 5, 2024
7a40c43
add test case and clean code
zhtmike Sep 6, 2024
9aa6940
move diffusers modification to patch, refactor batch inference
zhtmike Sep 6, 2024
72b1238
update README and minor fix
zhtmike Sep 9, 2024
7ded144
update README
zhtmike Sep 9, 2024
3 changes: 3 additions & 0 deletions examples/pixart_sigma/.gitignore
@@ -0,0 +1,3 @@
models/*
output/*
samples/*
181 changes: 181 additions & 0 deletions examples/pixart_sigma/README.md
@@ -0,0 +1,181 @@
# PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation (MindSpore)

This repo contains MindSpore model definitions, pre-trained weights, and inference/sampling code for the [paper](https://arxiv.org/abs/2403.04692) exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on the [official project page](https://pixart-alpha.github.io/PixArt-sigma-project/).

## Contents

- Main
  - [Training](#vanilla-finetune)
  - [Inference](#getting-started)
  - [Use diffusers: coming soon]
  - [Launch Demo: coming soon]
- Guidance
  - [Feature extraction: coming soon]
  - [One step Generation (DMD): coming soon]
  - [LoRA & DoRA: coming soon]
- Benchmark
  - [Training](#training)
  - [Inference](#inference)

## What's New
- 2024-09-05
  - Support fine-tuning and inference for PixArt-Sigma models.

## Dependencies and Installation

- CANN: 8.0.RC2 or later
- Python: 3.9 or later
- MindSpore: 2.3.1

Then, run `pip install -r requirements.txt` to install the necessary packages.

## Getting Started

### Downloading Pretrained Checkpoints

Please refer to the [official repository of PixArt-sigma](https://github.com/PixArt-alpha/PixArt-sigma) for downloading the pretrained checkpoints.

After downloading `PixArt-Sigma-XL-2-256x256.pth` and `PixArt-Sigma-XL-2-{}-MS.pth`, place them under the `models/` directory and run `tools/convert.py` on each checkpoint separately. For example, to convert `models/PixArt-Sigma-XL-2-1024-MS.pth`, run:

```bash
python tools/convert.py --source models/PixArt-Sigma-XL-2-1024-MS.pth --target models/PixArt-Sigma-XL-2-1024-MS.ckpt
```

> Note: You must have an environment with `PyTorch` installed to run the conversion script.

In addition, please download the [VAE checkpoint](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae), [T5 checkpoint](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/text_encoder), and [T5 tokenizer](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/tokenizer), and put them under the `models/` directory.
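If you prefer, these three folders can be fetched in one step with the `huggingface_hub` Python package (a minimal sketch, assuming the package is installed; the repo id comes from the links above):

```python
# Sketch: download the vae/, text_encoder/, and tokenizer/ folders into models/.
# Assumes `pip install huggingface_hub`; the repo id is taken from the links above.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    local_dir="models",
    allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"],
)
```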


After conversion, the contents of `models/` should look like this:
```bash
models/
├── PixArt-Sigma-XL-2-256x256.ckpt
├── PixArt-Sigma-XL-2-512-MS.ckpt
├── PixArt-Sigma-XL-2-1024-MS.ckpt
├── PixArt-Sigma-XL-2-2K-MS.ckpt
├── vae/
├── tokenizer/
└── text_encoder/
```

### Sampling with the Pretrained Model

You can then run sampling with `sample.py`. For example, to sample a 512x512 image, run

```bash
python sample.py -c configs/inference/pixart-sigma-512-MS.yaml --prompt "your magic prompt"
```

For higher resolution images, you can choose either `configs/inference/pixart-sigma-1024-MS.yaml` or `configs/inference/pixart-sigma-2K-MS.yaml`.

To sample an image with a different aspect ratio, use the `--image_width` and `--image_height` flags. For example, to sample an image of width 1024 and height 512, run

```bash
python sample.py -c configs/inference/pixart-sigma-1024-MS.yaml --prompt "your magic prompt" --image_width 1024 --image_height 512
```

The demo image below was generated with the following command:

```bash
python sample.py -c configs/inference/pixart-sigma-1024-MS.yaml --image_width 1024 --image_height 512 --seed 1024 --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
```
<p align="center"><img width="1024" src="https://github.com/user-attachments/assets/bcf12b8d-1077-451b-a6ae-51bbf3c8de7a"/></p>

You can also generate images from a text file that stores one prompt per line:

```bash
python sample.py -c configs/inference/pixart-sigma-1024-MS.yaml --prompt_path path_to_your_text_file
```
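For example, a prompt file (here named `prompts.txt`; any path works) could contain:

```text
Astronaut in a jungle, cold color palette, muted colors, detailed, 8k
A small cactus with a happy face in the Sahara desert.
```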

For more detailed usage of the inference script, please run `python sample.py -h`.

### Vanilla Finetune

We support finetuning the PixArt-Σ model on Ascend 910* devices.

#### Prepare the Dataset

- As an example, please download the `diffusiondb-pixelart` dataset from [this link](https://huggingface.co/datasets/jainr3/diffusiondb-pixelart). The dataset is a subset of the larger DiffusionDB 2M dataset, which has been transformed into pixel-style art.

- Once you have the dataset, create a label JSON file in the following format (a helper sketch for generating it follows the notes below):
```json
[
    {
        "path": "file1.png",
        "prompt": "a beautiful photorealistic painting of cemetery urbex unfinished building building industrial architecture...",
        "sharegpt4v": "*caption from ShareGPT4V*",
        "height": 512,
        "width": 512,
        "ratio": 1.0
    }
]
```
- Remember to
  - Replace `file1.png` with the actual image file path.
  - Fill the `prompt` field with a text description of the image.
  - If you have captions generated by ShareGPT4V, put them in the `sharegpt4v` field; otherwise, copy the value of the `prompt` field.
  - The `height` and `width` fields hold the image height and width, and `ratio` is the value of `height / width`.
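For reference, below is a minimal sketch of building this label file with Pillow. It assumes captions are available as a `{filename: caption}` mapping and that `path` entries are stored relative to `--image_dir` (both assumptions; adapt to your data layout):

```python
# Sketch: build the label JSON from an image folder and a {filename: caption} map.
# Assumes Pillow is installed and that "path" entries are relative to --image_dir.
import json
from pathlib import Path

from PIL import Image


def build_labels(image_dir: str, captions: dict, out_path: str) -> None:
    records = []
    for path in sorted(Path(image_dir).glob("*.png")):
        with Image.open(path) as img:
            width, height = img.size  # PIL returns (width, height)
        prompt = captions.get(path.name, "")
        records.append({
            "path": path.name,
            "prompt": prompt,
            "sharegpt4v": prompt,  # reuse the prompt if no ShareGPT4V caption exists
            "height": height,
            "width": width,
            "ratio": height / width,
        })
    with open(out_path, "w") as f:
        json.dump(records, f, indent=4)
```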

#### Finetune the Model:

Use the following command to start the finetuning process:

```bash
python train.py \
-c configs/train/pixart-sigma-512-MS.yaml \
--json_path path_to_your_label_file \
--image_dir path_to_your_image_directory
```
- Remember to
- Replace `path_to_your_label_file` with the actual path to your label JSON file.
- Replace `path_to_your_image_directory` with the directory containing your images.

For more detailed usage of the training script, please run `python train.py -h`.

#### Distributed Training (Optional):

You can launch distributed training on multiple Ascend 910* devices:

```bash
msrun --worker_num=8 --local_worker_num=8 --log_dir="log" train.py \
-c configs/train/pixart-sigma-512-MS.yaml \
--json_path path_to_your_label_file \
--image_dir path_to_your_image_directory \
--use_parallel True
```
- Remember to
- Replace `path_to_your_label_file` with the actual path to your label JSON file.
- Replace `path_to_your_image_directory` with the directory containing your images.

## Benchmark

### Training
Collaborator: pls add qualitative or visual evaluation results

Collaborator (Author): done


| Context | Optimizer | Global Batch Size | Resolution | Bucket Training | VAE/T5 Cache | Speed (step/s) | FPS (img/s) | Config |
|---------------|-----------|-------------------|------------|-----------------|--------------|----------------|-------------|---------------------------------------------------------------------|
| D910*x4-MS2.3 | CAME | 4 x 64 | 256x256 | No | No | 0.344 | 88.1 | [pixart-sigma-256x256.yaml](configs/train/pixart-sigma-256x256.yaml)|
| D910*x4-MS2.3 | CAME | 4 x 32 | 512 | Yes | No | 0.262 | 33.5 | [pixart-sigma-512-MS.yaml](configs/train/pixart-sigma-512-MS.yaml) |
| D910*x4-MS2.3 | CAME | 4 x 12 | 1024 | Yes | No | 0.142 | 6.8 | [pixart-sigma-1024-MS.yaml](configs/train/pixart-sigma-1024-MS.yaml)|
| D910*x4-MS2.3 | CAME | 4 x 1 | 2048 | Yes | No | 0.114 | 0.5 | [pixart-sigma-2K-MS.yaml](configs/train/pixart-sigma-2K-MS.yaml) |

> Context: {Ascend chip}x{number of NPUs}-{MindSpore version}.\
> Bucket Training: training on images of different aspect ratios via bucketing.\
> VAE/T5 Cache: use pre-generated T5 embeddings and VAE latents for training.\
> Speed (step/s): training speed measured in training steps per second.\
> FPS (img/s): number of training images processed per second. Average step time (s/step) = global batch size / FPS; e.g., the first row gives 256 / 88.1 ≈ 2.9 s/step.
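For intuition, bucket training assigns each sample to the aspect-ratio bucket closest to its own ratio, so every batch contains tensors of a single shape. A toy sketch (the bucket set below is hypothetical, not the one used by `train.py`):

```python
# Toy illustration of aspect-ratio bucketing. The bucket list is hypothetical
# and not the set actually used by the training script.
BUCKETS = [(512, 512), (448, 576), (576, 448), (384, 640), (640, 384)]


def nearest_bucket(height: int, width: int):
    """Return the (height, width) bucket whose aspect ratio is closest."""
    ratio = height / width
    return min(BUCKETS, key=lambda hw: abs(hw[0] / hw[1] - ratio))
```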

### Inference

| Context | Scheduler | Steps | Resolution | Batch Size | Speed (step/s) | Config |
|---------------|-----------|-------|--------------|------------|----------------|-------------------------------------------------------------------------|
| D910*x1-MS2.3 | DPM++ | 20 | 256 x 256 | 1 | 18.04 | [pixart-sigma-256x256.yaml](configs/inference/pixart-sigma-256x256.yaml)|
| D910*x1-MS2.3 | DPM++ | 20 | 512 x 512 | 1 | 15.95 | [pixart-sigma-512-MS.yaml](configs/inference/pixart-sigma-512-MS.yaml) |
| D910*x1-MS2.3 | DPM++ | 20 | 1024 x 1024 | 1 | 4.96 | [pixart-sigma-1024-MS.yaml](configs/inference/pixart-sigma-1024-MS.yaml)|
| D910*x1-MS2.3 | DPM++ | 20 | 2048 x 2048 | 1 | 0.57 | [pixart-sigma-2K-MS.yaml](configs/inference/pixart-sigma-2K-MS.yaml) |

> Context: {Ascend chip}x{number of NPUs}-{MindSpore version}.\
> Speed (step/s): sampling speed measured in sampling steps per second; e.g., at 0.57 step/s, a 20-step 2K sample takes about 35 s.

## References

[1] Junsong Chen, Chongjian Ge, Enze Xie, Yue Wu, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li. PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. arXiv:2403.04692, 2024.
20 changes: 20 additions & 0 deletions examples/pixart_sigma/configs/inference/pixart-sigma-1024-MS.yaml
@@ -0,0 +1,20 @@
# model
image_height: 1024
image_width: 1024
sample_size: 128
checkpoint: "models/PixArt-Sigma-XL-2-1024-MS.ckpt"
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
20 changes: 20 additions & 0 deletions examples/pixart_sigma/configs/inference/pixart-sigma-256x256.yaml
@@ -0,0 +1,20 @@
# model
image_height: 256
image_width: 256
sample_size: 32
checkpoint: "models/PixArt-Sigma-XL-2-256x256.ckpt"
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
@@ -0,0 +1,25 @@
# model
image_height: 2048
image_width: 2048
sample_size: 256
checkpoint: ""
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

kv_compress: True
kv_compress_sampling: "conv"
kv_compress_scale_factor: 2
kv_compress_layer: [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
20 changes: 20 additions & 0 deletions examples/pixart_sigma/configs/inference/pixart-sigma-2K-MS.yaml
@@ -0,0 +1,20 @@
# model
image_height: 2048
image_width: 2048
sample_size: 256
checkpoint: "models/PixArt-Sigma-XL-2-2K-MS.ckpt"
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
@@ -0,0 +1,25 @@
# model
image_height: 4096
image_width: 4096
sample_size: 512
checkpoint: ""
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

kv_compress: True
kv_compress_sampling: "conv"
kv_compress_scale_factor: 2
kv_compress_layer: [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
20 changes: 20 additions & 0 deletions examples/pixart_sigma/configs/inference/pixart-sigma-512-MS.yaml
@@ -0,0 +1,20 @@
# model
image_height: 512
image_width: 512
sample_size: 64
checkpoint: "models/PixArt-Sigma-XL-2-512-MS.ckpt"
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

# sampling
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
seed: 42

prompt:
- "A small cactus with a happy face in the Sahara desert."
44 changes: 44 additions & 0 deletions examples/pixart_sigma/configs/train/pixart-sigma-1024-MS.yaml
@@ -0,0 +1,44 @@
# model
sample_size: 128
batch_size: 12
checkpoint: "models/PixArt-Sigma-XL-2-1024-MS.ckpt"
vae_root: "models/vae"
text_encoder_root: "models/text_encoder"
tokenizer_root: "models/tokenizer"
sd_scale_factor: 0.13025
enable_flash_attention: True
dtype: "fp16"

# training hyper-parameters
epochs: 100
scheduler: "constant"
start_learning_rate: 2.0e-5
optim: "came"
came_betas: [0.9, 0.999, 0.9999]
came_eps: [1.0e-30, 1.0e-16]
weight_decay: 0.0
loss_scaler_type: "dynamic"
init_loss_scale: 65536.0
gradient_accumulation_steps: 1
clip_grad: True
max_grad_norm: 0.01
ckpt_save_interval: 5
log_loss_interval: 1
recompute: True
multi_scale: True
class_dropout_prob: 0.1
real_prompt_ratio: 0.5
warmup_steps: 1000
auto_lr: sqrt

# visualization
visualize: False # disabled for now to prevent potential OOM
visualize_interval: 5
sampling_method: "dpm"
sampling_steps: 20
guidance_scale: 4.5
validation_prompts:
- "portrait photo of a girl, photograph, highly detailed face, depth of field"
- "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
- "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece"