Commit

update latte/dit to ms2.3
jianyunchao committed Aug 28, 2024
1 parent 7c854c5 commit 197a5dc
Showing 11 changed files with 541 additions and 47 deletions.
68 changes: 31 additions & 37 deletions examples/dit/README.md
@@ -73,59 +73,41 @@ seed: 42
ddim_sampling: True
```
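For reference, a single deterministic DDIM update (the `eta = 0` case enabled by `ddim_sampling: True` above) can be sketched as follows; the function and variable names are illustrative and this is not the repository's sampler code.

```python
import numpy as np

def ddim_step(x_t: np.ndarray, eps: np.ndarray, alpha_bar_t: float, alpha_bar_prev: float) -> np.ndarray:
    """One deterministic DDIM step: predict x0 from the noise estimate, then re-noise it to the previous timestep."""
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps) / np.sqrt(alpha_bar_t)
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps
```

With 50 sampling steps, this update is applied over a sub-sequence of 50 timesteps instead of all 250 DDPM steps, which is why DDIM sampling is faster.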
The inference speed of the experiments with the `256x256` image size is summarized in the following table:

| model name | context | cards | image size | method | steps | ckpt loading time | graph compile | sample time |
| :--------: | :----------: | :---: | :--------: | :----: | :---: | :---------------: | :-----------: | :---------: |
| dit | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 16.41s | 82.83s | 58.45s |

Some generated example images are shown below:
<p float="center">
<img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/512x512/class-207.png" width="25%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/512x512/class-360.png" width="25%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/512x512/class-417.png" width="25%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/512x512/class-979.png" width="25%" />
<img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-207.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-360.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-417.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-979.png" width="25%" />
</p>
<p float="center">
<img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-207.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-279.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-360.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-387.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-417.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-88.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-974.png" width="12.5%" /><img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/256x256/class-979.png" width="12.5%" />
<img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-207.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-279.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-360.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-387.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-417.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-88.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-974.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-979.png" width="12.5%" />
</p>

## Model Finetuning
Now, we support finetuning the DiT model on a toy dataset `imagenet_samples/images/`. It consists of three sample images randomly selected from the ImageNet dataset and their corresponding class labels. This toy dataset is hosted at this [website](https://github.com/wtomin/mindone-assets/tree/main/dit/imagenet_samples). You can also download it using:
## Model Training with ImageNet dataset

```bash
bash scripts/download_toy_dataset.sh
```
Afterwards, the toy dataset is saved in the `imagenet_samples/` folder.
For `mindspore>=2.3.0`, it is recommended to use `msrun` to launch 4-card distributed training on an ImageNet-format dataset with the following command:

To finetune the DiT model conditioned on class labels on Ascend devices, use:
```bash
python train.py --config configs/training/class_cond_finetune.yaml
```

You can adjust the hyper-parameters in the yaml file:
```yaml
# training hyper-params
start_learning_rate: 5e-5 # small lr for finetuning exps. Change it to 1e-4 for regular training tasks.
scheduler: "constant"
warmup_steps: 10
train_batch_size: 2
gradient_accumulation_steps: 1
weight_decay: 0.01
epochs: 3000
```
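For reference, the `constant` scheduler with warmup implied by `start_learning_rate` and `warmup_steps` above can be sketched as below; this is illustrative only, not the repository's scheduler implementation.

```python
def constant_lr_with_warmup(step: int, base_lr: float = 5e-5, warmup_steps: int = 10) -> float:
    """Ramp the learning rate linearly over warmup_steps, then hold it constant."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr

lrs = [constant_lr_with_warmup(s) for s in range(20)]  # ramps to 5e-5 in 10 steps, then stays flat
```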

After training, the checkpoints will be saved under `output_folder/ckpt/`.

To run inference with a certain checkpoint file, please first revise the `dit_checkpoint` path in the yaml files under `configs/inference/`, for example:
```
# dit-xl-2-256x256.yaml
dit_checkpoint: "outputs/ckpt/DiT-3000.ckpt"
```

```bash
msrun --worker_num=4 \
--local_worker_num=4 \
--bind_core=True \
--log_dir=msrun_log \
python train.py \
-c configs/training/class_cond_train.yaml \
--data_path PATH_TO_YOUR_DATASET \
--use_parallel True
```

Then run `python sample.py -c config-file-path`.
## Model Training with ImageNet dataset
You can also start distributed training on an ImageNet-format dataset using the following command:

```bash
export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"
mpirun -n 4 python train.py \
-c configs/training/class_cond_train.yaml \
--dataset_path PATH_TO_YOUR_DATASET \
--data_path PATH_TO_YOUR_DATASET \
--use_parallel True
```
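For reference, the sketch below shows the kind of setup that `--use_parallel True` typically implies for MindSpore data-parallel training; it is illustrative only, and the actual wiring lives in the training script.

```python
import mindspore as ms
from mindspore.communication import init, get_group_size, get_rank

init()  # set up the communication group for the workers launched by mpirun/msrun
ms.set_auto_parallel_context(
    parallel_mode="data_parallel",  # every card keeps a full model copy and trains on a data shard
    gradients_mean=True,            # average gradients across cards after all-reduce
    device_num=get_group_size(),
)
rank_id = get_rank()  # used, e.g., to restrict logging and checkpointing to rank 0
```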

@@ -144,6 +126,18 @@ to launch a 4P training. For detailed usage of the training script, please run
bash scripts/run_distributed.sh -h
```

## Evaluation

The training speed of the experiments with the `256x256` image size is summarized in the following table:

| model name | context | cards | image size | graph compile | bs | Recompute | sink | step time | train. imgs/s |
| :--------: | :----------: | :---: | :--------: | :-----------: | :--: | :-------: | :--: | :-------: | :-----------: |
| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | OFF | ON | 0.89s | 71.91 |
| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | ON | ON | 0.95s | 67.37 |
| dit | 910*-MS2.3.1 | 4p | 256x256 | 3~5 mins | 64 | ON | ON | 1.03s | 248.52 |
| dit | 910*-MS2.3.1 | 8p | 256x256 | 3~5 mins | 64 | ON | ON | 0.93s | 515.61 |


# References

[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020. 1, 2, 4, 5
144 changes: 144 additions & 0 deletions examples/dit/README_CN.md
@@ -0,0 +1,144 @@
# Scalable Diffusion Models with Transformers (DiT)

## Introduction

Previous common diffusion models (e.g., Stable Diffusion) use a U-Net backbone, which lacks scalability. DiT is a new family of diffusion models based on the Transformer architecture. The authors designed Diffusion Transformers (DiTs), which follow the best practices of Vision Transformers (ViTs) [1]. The visual input is "patchified" into a sequence of visual tokens, which is then processed by a sequence of Transformer blocks (DiT blocks). The structures of the DiT model and the DiT block are shown below:

<p align="center">
<img src="https://raw.githubusercontent.com/wtomin/mindone-assets/main/dit/DiT_structure.PNG" width=550 />
</p>
<p align="center">
<em> Figure 1. Structures of DiT and the DiT block. [<a href="#references">2</a>] </em>
</p>
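A minimal sketch of the "patchify" step described above is given below; the shapes and the function name are illustrative and do not come from the repository's implementation.

```python
import numpy as np

def patchify(latent: np.ndarray, p: int = 2) -> np.ndarray:
    """Split a (C, H, W) latent into non-overlapping p x p patches, each flattened into one visual token."""
    c, h, w = latent.shape
    x = latent.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0).reshape(-1, p * p * c)  # (num_tokens, token_dim)
    return x

tokens = patchify(np.zeros((4, 32, 32), dtype=np.float32))  # -> (256, 16): a 16x16 grid of 16-dim tokens
```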


DiTs are a scalable architecture for diffusion models. The authors found a strong correlation between network complexity (measured in Gflops) and sample quality (measured in FID). In other words, the more complex the DiT model, the better it performs at image generation.

## Getting Started

This tutorial introduces how to run inference and finetuning experiments with MindONE.

### Environment Setup

```
pip install -r requirements.txt
```

### Pretrained Checkpoints

We refer to the [official repository of DiT](https://github.com/facebookresearch/DiT) for downloading the pretrained checkpoints. Currently, only two checkpoints, `DiT-XL-2-256x256` and `DiT-XL-2-512x512`, are available.

After downloading the `DiT-XL-2-{}x{}.pt` file, please place it under the `models/` folder, and then run `tools/dit_converter.py`. For example, to convert `models/DiT-XL-2-256x256.pt`, you can run:
```bash
python tools/dit_converter.py --source models/DiT-XL-2-256x256.pt --target models/DiT-XL-2-256x256.ckpt
```

In addition, please download the VAE checkpoint from [huggingface/stabilityai.co](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main) and convert it by running:
```bash
python tools/vae_converter.py --source path/to/vae/ckpt --target models/sd-vae-ft-mse.ckpt
```

After conversion, the checkpoints under `models/` should look like this:
```bash
models/
├── DiT-XL-2-256x256.ckpt
├── DiT-XL-2-512x512.ckpt
└── sd-vae-ft-mse.ckpt
```
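As an optional sanity check (assuming MindSpore is installed), you can verify that a converted checkpoint loads correctly:

```python
import mindspore as ms

# Load the converted DiT checkpoint and report how many parameters it contains.
param_dict = ms.load_checkpoint("models/DiT-XL-2-256x256.ckpt")
print(f"Loaded {len(param_dict)} parameters")
```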

## Sampling
To run inference of the `DiT-XL/2` model with the `256x256` image size on Ascend devices, you can use:
```bash
python sample.py -c configs/inference/dit-xl-2-256x256.yaml
```

To run inference of the `DiT-XL/2` model with the `512x512` image size on Ascend devices, you can use:
```bash
python sample.py -c configs/inference/dit-xl-2-512x512.yaml
```

To run the same inference on GPU devices, simply append `--device_target GPU` to the commands above.

By default, we run DiT inference in mixed precision mode with `amp_level="O2"`. If you want to run inference in full precision mode, please set `use_fp16: False` in the inference yaml file.
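For reference, a minimal sketch of how O2 mixed precision is typically applied to a network in MindSpore is shown below; it is illustrative only, and the repository wires this up through its own `use_fp16` / `amp_level` configuration.

```python
from mindspore import amp, nn

net = nn.Dense(4, 4)  # stand-in for the DiT network
# Keep numerically sensitive ops in fp32 and run the rest in fp16.
net = amp.auto_mixed_precision(net, amp_level="O2")
```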

For diffusion sampling, we use the same settings as the [official repository of DiT](https://github.com/facebookresearch/DiT):

- The default sampler is the DDPM sampler, and the default number of sampling steps is 250.
- For classifier-free guidance, the default guidance scale is $4.0$.

If you want to use the DDIM sampler and sample for 50 steps, you can modify the inference yaml file as follows:
```yaml
# sampling
sampling_steps: 50
guidance_scale: 4.0
seed: 42
ddim_sampling: True
```
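For reference, the sketch below shows how classifier-free guidance combines the conditional and unconditional noise predictions with `guidance_scale`; the names are illustrative, not the repository's implementation.

```python
import numpy as np

def apply_cfg(eps_cond: np.ndarray, eps_uncond: np.ndarray, guidance_scale: float = 4.0) -> np.ndarray:
    """Push the prediction away from the unconditional estimate by the guidance scale."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# guidance_scale = 1.0 recovers the purely conditional prediction; larger values trade diversity for fidelity.
```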
The inference speed of the experiments with the `256x256` image size is summarized in the following table:

| model name | context | cards | image size | method | steps | ckpt loading time | graph compile | sample time |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| dit | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 16.41s | 82.83s | 58.45s |

Some generated example images are shown below:
<p float="center">
<img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-207.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-360.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-417.png" width="25%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/512x512/class-979.png" width="25%" />
</p>
<p float="center">
<img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-207.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-279.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-360.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-387.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-417.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-88.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-974.png" width="12.5%" /><img src="https://raw.githubusercontent.com/jianyunchao/mindone-assets/v0.2.0/dit/256x256/class-979.png" width="12.5%" />
</p>

## Model Training with ImageNet dataset

For `mindspore>=2.3.0`, it is recommended to use `msrun` to launch 4-card distributed training on an ImageNet-format dataset with the following command:
```bash
msrun --worker_num=4 \
--local_worker_num=4 \
--bind_core=True \
--log_dir=msrun_log \
python train.py \
-c configs/training/class_cond_train.yaml \
--data_path PATH_TO_YOUR_DATASET \
--use_parallel True
```

You can also start distributed training on an ImageNet-format dataset using the following command:
```bash
mpirun -n 4 python train.py \
-c configs/training/class_cond_train.yaml \
--data_path PATH_TO_YOUR_DATASET \
--use_parallel True
```

where `PATH_TO_YOUR_DATASET` is the path to your ImageNet dataset, e.g., `ImageNet2012/train`.

For machines with Ascend devices, you can also launch distributed training with a rank table.
Please run
```bash
bash scripts/run_distributed.sh path_of_the_rank_table 0 4 path_to_your_dataset
```

to launch a 4P training. For detailed usage of the training script, please run
```bash
bash scripts/run_distributed.sh -h
```

## Evaluation

The training speed of the experiments with the `256x256` image size is summarized in the following table:

| model name | context | cards | image size | graph compile | bs | Recompute | sink | step time | train. imgs/s |
| :--------: | :----: | :---: | :--------: | :-----------: | :--: | :-------: | :--: | :-------: | :-----------: |
| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | OFF | ON | 0.89s | 71.91 |
| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | ON | ON | 0.95s | 67.37 |
| dit | 910*-MS2.3.1 | 4p | 256x256 | 3~5 mins | 64 | ON | ON | 1.03s | 248.52 |
| dit | 910*-MS2.3.1 | 8p | 256x256 | 3~5 mins | 64 | ON | ON | 0.93s | 515.61 |


# References

[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020. 1, 2, 4, 5

[2] W. Peebles and S. Xie, “Scalable diffusion models with transformers,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4195–4205, 2023
11 changes: 11 additions & 0 deletions examples/dit/args_train.py
@@ -193,6 +193,17 @@ def parse_args():

parser.add_argument("--imagenet_format", type=str2bool, help="Training with ImageNet dataset format")

parser.add_argument(
"--jit_level",
default="O2",
type=str,
choices=["O0", "O1", "O2"],
help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]."
"O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode."
"O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode."
"O2: Ultimate performance optimization, adopt Sink execution mode.",
)
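# Illustrative usage of the flag defined above (hypothetical invocation; adjust the config path as needed):
#   python train.py -c configs/training/class_cond_train.yaml --jit_level O1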

abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ""))
default_args = parser.parse_args()
if default_args.config:
24 changes: 24 additions & 0 deletions examples/dit/sample.py
@@ -39,6 +39,20 @@ def init_env(args):
device_target=args.device_target,
device_id=device_id,
)
if args.mode == ms.GRAPH_MODE:
try:
if args.jit_level in ["O0", "O1", "O2"]:
ms.set_context(jit_config={"jit_level": args.jit_level})
logger.info(f"set jit_level: {args.jit_level}.")
else:
logger.warning(
f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method"
)
except Exception:
logger.warning(
"The current jit_level is not suitable because current MindSpore version does not match,"
"please ensure the MindSpore version >= ms2.3.0."
)
if args.precision_mode is not None:
ms.set_context(ascend_config={"precision_mode": args.precision_mode})

@@ -119,6 +133,16 @@ def parse_args():
)
parser.add_argument("--ddim_sampling", type=str2bool, default=True, help="Whether to use DDIM for sampling")
parser.add_argument("--imagegrid", default=False, type=str2bool, help="Save the image in image-grids format.")
parser.add_argument(
"--jit_level",
default="O2",
type=str,
choices=["O0", "O1", "O2"],
help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]."
"O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode."
"O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode."
"O2: Ultimate performance optimization, adopt Sink execution mode.",
)
default_args = parser.parse_args()
abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ""))
if default_args.config:
14 changes: 14 additions & 0 deletions examples/dit/train.py
@@ -75,6 +75,20 @@ def main(args):
max_device_memory=args.max_device_memory,
ascend_config=None if args.precision_mode is None else {"precision_mode": args.precision_mode},
)
if args.ms_mode == ms.GRAPH_MODE:
try:
if args.jit_level in ["O0", "O1", "O2"]:
ms.set_context(jit_config={"jit_level": args.jit_level})
logger.info(f"set jit_level: {args.jit_level}.")
else:
logger.warning(
f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method"
)
except Exception:
logger.warning(
"The current jit_level is not suitable because current MindSpore version does not match,"
"please ensure the MindSpore version >= ms2.3_0615."
)
set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level))

# 2. model initiate and weight loading
