From 197a5dc74785b23fa0ee2aadf16e5fd73a04948f Mon Sep 17 00:00:00 2001 From: jianyunchao Date: Thu, 8 Aug 2024 14:26:00 +0800 Subject: [PATCH] update latte/dit to ms2.3 --- examples/dit/README.md | 68 +++++---- examples/dit/README_CN.md | 144 +++++++++++++++++++ examples/dit/args_train.py | 11 ++ examples/dit/sample.py | 24 ++++ examples/dit/train.py | 14 ++ examples/latte/README.md | 38 +++-- examples/latte/README_CN.md | 238 ++++++++++++++++++++++++++++++++ examples/latte/args_train.py | 10 ++ examples/latte/requirements.txt | 3 + examples/latte/sample.py | 24 ++++ examples/latte/train.py | 14 ++ 11 files changed, 541 insertions(+), 47 deletions(-) create mode 100644 examples/dit/README_CN.md create mode 100644 examples/latte/README_CN.md diff --git a/examples/dit/README.md b/examples/dit/README.md index b909f93127..ad358c84a1 100644 --- a/examples/dit/README.md +++ b/examples/dit/README.md @@ -73,59 +73,41 @@ seed: 42 ddim_sampling: True ``` +The inference speed of the experiments with `256x256` image size is summarized in the following table: + +| model name | context | cards | image size | method | steps | ckpt loading time | graph compile | sample time | +| :--------: | :----------: | :---: | :--------: | :----: | :---: | :---------------: | :-----------: | :---------: | +| dit | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 16.41s | 82.83s | 58.45s | + Some generated example images are shown below:

- +

- +

-## Model Finetuning - -Now, we support finetuning DiT model on a toy dataset `imagenet_samples/images/`. It consists of three sample images randomly selected from ImageNet dataset and their corresponding class labels. This toy dataset is stored at this [website](https://github.com/wtomin/mindone-assets/tree/main/dit/imagenet_samples). You can also download this toy dataset using: +## Model Training with ImageNet dataset -```bash -bash scripts/download_toy_dataset.sh -``` -Afterwards, the toy dataset is saved in `imagenet_samples/` folder. +For `mindspore>=2.3.0`, it is recommended to use msrun to launch the 4-card distributed training with ImageNet dataset format using the following command: -To finetune DiT model conditioned on class labels on Ascend devices, use: ```bash -python train.py --config configs/training/class_cond_finetune.yaml -``` - -You can adjust the hyper-parameters in the yaml file: -```yaml -# training hyper-params -start_learning_rate: 5e-5 # small lr for finetuning exps. Change it to 1e-4 for regular training tasks. -scheduler: "constant" -warmup_steps: 10 -train_batch_size: 2 -gradient_accumulation_steps: 1 -weight_decay: 0.01 -epochs: 3000 -``` - -After training, the checkpoints will be saved under `output_folder/ckpt/`. - -To run inference with a certain checkpoint file, please first revise `dit_checkpoint` path in the yaml files under `configs/inference/`, for example, -``` -# dit-xl-2-256x256.yaml -dit_checkpoint: "outputs/ckpt/DiT-3000.ckpt" +msrun --worker_num=4 \ + --local_worker_num=4 \ + --bind_core=True \ + --log_dir=msrun_log \ + python train.py \ + -c configs/training/class_cond_train.yaml \ + --data_path PATH_TO_YOUR_DATASET \ + --use_parallel True ``` -Then run `python sample.py -c config-file-path`. - -## Model Training with ImageNet dataset - You can start the distributed training with ImageNet dataset format using the following command ```bash -export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" mpirun -n 4 python train.py \ -c configs/training/class_cond_train.yaml \ - --dataset_path PATH_TO_YOUR_DATASET \ + --data_path PATH_TO_YOUR_DATASET \ --use_parallel True ``` @@ -144,6 +126,18 @@ to launch a 4P training. For detail usage of the training script, please run bash scripts/run_distributed.sh -h ``` +## Evaluation + +The training speed of the experiments with `256x256` image size is summarized in the following table: + +| model name | context | cards | image size | graph compile | bs | Recompute | sink | step time | train. imgs/s | +| :--------: | :----------: | :---: | :--------: | :-----------: | :--: | :-------: | :--: | :-------: | :-----------: | +| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | OFF | ON | 0.89s | 71.91 | +| dit | 910*-MS2.3.1 | 1p | 256x256 | 3~5 mins | 64 | ON | ON | 0.95s | 67.37 | +| dit | 910*-MS2.3.1 | 4p | 256x256 | 3~5 mins | 64 | ON | ON | 1.03s | 248.52 | +| dit | 910*-MS2.3.1 | 8p | 256x256 | 3~5 mins | 64 | ON | ON | 0.93s | 515.61 | + + # References [1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020. 1, 2, 4, 5 diff --git a/examples/dit/README_CN.md b/examples/dit/README_CN.md new file mode 100644 index 0000000000..6d915d6836 --- /dev/null +++ b/examples/dit/README_CN.md @@ -0,0 +1,144 @@ +# 可扩展的基于转换器的扩散模型(DiT) + +## 引言 + +以往常见的扩散模型(例如:稳定扩散模型)使用的是U-Net骨干网络,这缺乏可扩展性。DiT是一类基于转换器架构的新型扩散模型。作者设计了扩散转换器(DiTs),它们遵循视觉转换器(ViTs)[1]的最佳实践。它通过"patchify"将视觉输入作为一系列视觉标记的序列,然后由一系列转换器块(DiT块)处理这些输入。DiT模型和DiT块的结构如下所示: + +

+ +

+

+ 图 1. DiT和DiT块的结构。 [2] +

+ + +DiTs是扩散模型的可扩展架构。作者发现网络复杂性(以Gflops计)与样本质量(以FID计)之间存在强相关性。换句话说,DiT模型越复杂,其在图像生成上的表现就越好。 + +## 开始使用 + +本教程将介绍如何使用MindONE运行推理和微调实验。 + +### 环境设置 + +``` +pip install -r requirements.txt +``` + +### 预训练检查点 + +我们参考[DiT的官方仓库](https://github.com/facebookresearch/DiT)下载预训练检查点。目前,只有两个检查点`DiT-XL-2-256x256`和`DiT-XL-2-512x512`可用。 + +下载`DiT-XL-2-{}x{}.pt`文件后,请将其放置在`models/`文件夹下,然后运行`tools/dit_converter.py`。例如,要转换`models/DiT-XL-2-256x256.pt`,您可以运行以下命令: +```bash +python tools/dit_converter.py --source models/DiT-XL-2-256x256.pt --target models/DiT-XL-2-256x256.ckpt +``` + +此外,还请从[huggingface/stabilityai.co](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)下载VAE检查点,并通过运行以下命令进行转换: +```bash +python tools/vae_converter.py --source path/to/vae/ckpt --target models/sd-vae-ft-mse.ckpt +``` + +转换后,在`models/`下的检查点应如下所示: +```bash +models/ +├── DiT-XL-2-256x256.ckpt +├── DiT-XL-2-512x512.ckpt +└── sd-vae-ft-mse.ckpt +``` + +## 采样 +要在Ascend设备上运行`DiT-XL/2`模型的`256x256`图像尺寸的推理,您可以使用: +```bash +python sample.py -c configs/inference/dit-xl-2-256x256.yaml +``` + +要在Ascend设备上运行`DiT-XL/2`模型的`512x512`图像尺寸的推理,您可以使用: +```bash +python sample.py -c configs/inference/dit-xl-2-512x512.yaml +``` + +要在GPU设备上运行相同的推理,只需按上述命令额外设置`--device_target GPU`。 + +默认情况下,我们以混合精度模式运行DiT推理,其中`amp_level="O2"`。如果您想以全精度模式运行推理,请在推理yaml文件中设置`use_fp16: False`。 + +对于扩散采样,我们使用与[DiT的官方仓库](https://github.com/facebookresearch/DiT)相同的设置: + +- 默认采样器是DDPM采样器,默认采样步数是250。 +- 对于无分类器引导,默认为引导比例是 $4.0$. + +如果您想使用DDIM采样器并采样50步,您可以按以下方式修改推理yaml文件: +```yaml +# 采样 +sampling_steps: 50 +guidance_scale: 4.0 +seed: 42 +ddim_sampling: True +``` + +`256x256`图像大小的实验推理速度总结在以下表格中: + +| 模型名称 | 环境 | 卡数 | 图像尺寸 | 方法 | 步数 | 检查点加载时间 | 编译时间 | 采样时间 | +| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | +| dit | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 16.41s | 82.83s | 58.45s | + +一些生成的示例图像如下所示: +Some generated example images are shown below: +

+ +

+

+ +

+ +## 使用ImageNet数据集进行模型训练 + +对于`mindspore>=2.3.0`,建议使用msrun启动使用ImageNet数据集格式的4卡分布式训练,使用命令如下: +```bash +msrun --worker_num=4 \ + --local_worker_num=4 \ + --bind_core=True \ + --log_dir=msrun_log \ + python train.py \ + -c configs/training/class_cond_train.yaml \ + --data_path PATH_TO_YOUR_DATASET \ + --use_parallel True +``` + +您可以使用以下命令使用ImageNet数据集格式启动分布式训练: +```bash +mpirun -n 4 python train.py \ + -c configs/training/class_cond_train.yaml \ + --data_path PATH_TO_YOUR_DATASET \ + --use_parallel True +``` + +其中`PATH_TO_YOUR_DATASET`是您的ImageNet数据集的路径,例如`ImageNet2012/train`。 + +对于配备Ascend设备的机器,您也可以使用排名表启动分布式训练。 +请运行 +```bash +bash scripts/run_distributed.sh path_of_the_rank_table 0 4 path_to_your_dataset +``` + +以启动4P训练。有关训练脚本的详细用法,请运行 +```bash +bash scripts/run_distributed.sh -h +``` + +## 评估 + +`256x256`图像大小的实验训练速度总结在以下表格中: + +| 模型名称 | 环境 | 卡数 | 图像尺寸 | 图构建编译 | 批量大小 | 重计算 | 下沉 | 每步时间 | 训练图像/秒 | +| :--------: | :----: | :---: | :--------: | :-----------: | :--: | :-------: | :--: | :-------: | :-----------: | +| dit | 910*-MS2.3.1 | 单卡 | 256x256 | 3~5 分钟 | 64 | 关 | 开 | 0.89秒 | 71.91 | +| dit | 910*-MS2.3.1 | 单卡 | 256x256 | 3~5 分钟 | 64 | 开 | 开 | 0.95秒 | 67.37 | +| dit | 910*-MS2.3.1 | 4卡 | 256x256 | 3~5 分钟 | 64 | 开 | 开 | 1.03秒 | 248.52 | +| dit | 910*-MS2.3.1 | 8卡 | 256x256 | 3~5 分钟 | 64 | 开 | 开 | 0.93秒 | 515.61 | + + +# 参考文献 + +[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020. 1, 2, 4, 5 + +[2] W. Peebles and S. Xie, “Scalable diffusion models with transformers,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4195–4205, 2023 diff --git a/examples/dit/args_train.py b/examples/dit/args_train.py index 49c6499720..feef207720 100644 --- a/examples/dit/args_train.py +++ b/examples/dit/args_train.py @@ -193,6 +193,17 @@ def parse_args(): parser.add_argument("--imagenet_format", type=str2bool, help="Training with ImageNet dataset format") + parser.add_argument( + "--jit_level", + default="O2", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) + abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) default_args = parser.parse_args() if default_args.config: diff --git a/examples/dit/sample.py b/examples/dit/sample.py index 8a0474e720..399671769d 100644 --- a/examples/dit/sample.py +++ b/examples/dit/sample.py @@ -39,6 +39,20 @@ def init_env(args): device_target=args.device_target, device_id=device_id, ) + if args.mode == ms.GRAPH_MODE: + try: + if args.jit_level in ["O0", "O1", "O2"]: + ms.set_context(jit_config={"jit_level": args.jit_level}) + logger.info(f"set jit_level: {args.jit_level}.") + else: + logger.warning( + f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method" + ) + except Exception: + logger.warning( + "The current jit_level is not suitable because current MindSpore version does not match," + "please ensure the MindSpore version >= ms2.3.0." + ) if args.precision_mode is not None: ms.set_context(ascend_config={"precision_mode": args.precision_mode}) @@ -119,6 +133,16 @@ def parse_args(): ) parser.add_argument("--ddim_sampling", type=str2bool, default=True, help="Whether to use DDIM for sampling") parser.add_argument("--imagegrid", default=False, type=str2bool, help="Save the image in image-grids format.") + parser.add_argument( + "--jit_level", + default="O2", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) default_args = parser.parse_args() abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) if default_args.config: diff --git a/examples/dit/train.py b/examples/dit/train.py index 4ba3a2204b..d59c8114db 100644 --- a/examples/dit/train.py +++ b/examples/dit/train.py @@ -75,6 +75,20 @@ def main(args): max_device_memory=args.max_device_memory, ascend_config=None if args.precision_mode is None else {"precision_mode": args.precision_mode}, ) + if args.ms_mode == ms.GRAPH_MODE: + try: + if args.jit_level in ["O0", "O1", "O2"]: + ms.set_context(jit_config={"jit_level": args.jit_level}) + logger.info(f"set jit_level: {args.jit_level}.") + else: + logger.warning( + f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method" + ) + except Exception: + logger.warning( + "The current jit_level is not suitable because current MindSpore version does not match," + "please ensure the MindSpore version >= ms2.3_0615." + ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) # 2. model initiate and weight loading diff --git a/examples/latte/README.md b/examples/latte/README.md index e5f04a1d9c..b9b41c19e4 100644 --- a/examples/latte/README.md +++ b/examples/latte/README.md @@ -50,7 +50,7 @@ Instruction on ffmpeg and decord install on EulerOS: 2. install decord, referring to https://github.com/dmlc/decord?tab=readme-ov-file#install-from-source git clone --recursive https://github.com/dmlc/decord cd decord - rm build && mkdir build && cd build + if [ -d build ];then rm build;fi && mkdir build && cd build cmake .. -DUSE_CUDA=0 -DCMAKE_BUILD_TYPE=Release make -j 64 make install @@ -79,6 +79,12 @@ For example, to run inference of `skytimelapse.ckpt` model with the `256x256` im python sample.py -c configs/inference/sky.yaml ``` +The inference speed of the experiments with `256x256` image size is summarized in the following table: + +| model name | context | cards | image size | method | steps | ckpt loading time | compile time | total sample time | +| :--------: | :----------: | :---: | :--------: | :----: | :---: | :---------------: | :----------: | :---------------: | +| latte | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 19.72s | 101.26s | 537.31s | + Some of the generated results are shown here: @@ -87,9 +93,9 @@ Some of the generated results are shown here: - - - + + +
Example 3

@@ -192,9 +198,20 @@ In case of OOM, please set `enable_flash_attention: True` in the `configs/traini ### 4.3 Distributed Training +For `mindspore>=2.3.0`, it is recommended to use msrun to launch the 4-card distributed training with ImageNet dataset format using the following command: + +``` +msrun --worker_num=4 \ + --local_worker_num=4 \ + --bind_core=True \ + --log_dir=msrun_log \ + python train.py \ + -c path/to/configuration/file \ + --use_parallel True +``` + Taking the 4-card distributed training as an example, you can start the distributed training using: ```bash -export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" mpirun -n 4 python train.py \ -c path/to/configuration/file \ --use_parallel True @@ -213,11 +230,12 @@ The first number `0` indicates the start index of the training devices, and the The training speed of the experiments with `256x256` image size is summarized in the following table: -| Cards | Recompute | Dataset Sink mode | Embedding Cache|Train. imgs/s | -| --- | --- | --- | --- | --- | -| 1 | OFF | ON | OFF | 62.3 | -| 1 | ON | ON | ON | 93.6 | -| 4 | ON | ON | ON | 368.3 | +| model name | context | cards | image size | graph compile | bs | recompute | sink | cache | step time | train. imgs/s | +| :--------: | :----------: | :---: | :--------: | :-----------: | :---: | :-------: | :--: | :---: | :-------: | :-----------: | +| latte | 910*-MS2.3.1 | 1 | 256x256 | 6~8 mins | 5x16 | OFF | ON | OFF | 1.03s | 77.67 | +| latte | 910*-MS2.3.1 | 1 | 256x256 | 6~8 mins | 1x128 | ON | ON | ON | 1.21s | 105.78 | +| latte | 910*-MS2.3.1 | 4 | 256x256 | 6~8 mins | 1x128 | ON | ON | ON | 1.32s | 387.87 | +| latte | 910*-MS2.3.1 | 8 | 256x256 | 6~8 mins | 1x128 | ON | ON | ON | 1.31s | 781.67 | # References diff --git a/examples/latte/README_CN.md b/examples/latte/README_CN.md new file mode 100644 index 0000000000..afd6d68071 --- /dev/null +++ b/examples/latte/README_CN.md @@ -0,0 +1,238 @@ +# Latte:用于视频生成的潜在扩散转换器 + +## 1. Latte 简介 + +Latte [1] 是一种新颖的潜在扩散转换器,专为视频生成而设计。它基于 DiT(一种用于图像生成的扩散转换器模型)构建。有关 DiT [2] 的介绍,请参见 [DiT 的 README](../dit/README_CN.md)。 + +Latte 首先使用 VAE(变分自编码器)将视频数据压缩到潜在空间,然后根据潜在编码提取空间-时间标记。与 DiT 类似,它堆叠多个转换器块来模拟潜在空间中的视频扩散。如何设计空间和时间块成为一个主要问题。 + +通过实验和分析,他们发现最佳实践是下图中的结构(a)。它交替堆叠空间块和时间块,轮流模拟空间注意力和时间注意力。 + +

+ +

+

+ 图 1. Latte 及其转换器块的结构。 [1] +

+ +与 DiT 类似,Latte 支持无条件视频生成和类别标签条件视频生成。此外,它还支持根据文本标题生成视频。 + +## 2. 快速开始 + +本教程将介绍如何使用 MindONE 运行推理和训练实验。 + +本教程包括: +- [x] 预训练检查点转换; +- [x] 使用预训练的 Latte 检查点进行无条件视频采样; +- [x] 在 Sky TimeLapse 数据集上训练无条件 Latte:支持(1)使用视频训练;和(2)使用嵌入缓存训练; +- [x] 混合精度:支持(1)Float16;(2)BFloat16(将 patch_embedder 设置为 "linear"); +- [x] 独立训练和分布式训练。 +- [ ] 文本到视频 Latte 推理和训练。 + +### 2.1 环境设置 + +`decord` 是视频生成所必需的。如果环境中没有 `decord` 包,请尝试 `pip install eva-decord`。 +``` +1. 安装 ffmpeg 4, 参考 https://ffmpeg.org/releases + wget wget https://ffmpeg.org/releases/ffmpeg-4.0.1.tar.bz2 --no-check-certificate + tar -xvf ffmpeg-4.0.1.tar.bz2 + mv ffmpeg-4.0.1 ffmpeg + cd ffmpeg + ./configure --enable-shared # --enable-shared 是必选项,为了 decord 共享 libavcodec 的编解码库 + make -j 64 + make install +2. 安装 decord, 参考 https://github.com/dmlc/decord?tab=readme-ov-file#install-from-source + git clone --recursive https://github.com/dmlc/decord + cd decord + if [ -d build ];then rm build;fi && mkdir build && cd build + cmake .. -DUSE_CUDA=0 -DCMAKE_BUILD_TYPE=Release + make -j 64 + make install + cd ../python + python3 setup.py install --user +``` + +### 2.2 预训练检查点 + +我们参考 [Latte 官方仓库](https://github.com/Vchitect/Latte/tree/main) 下载预训练检查点。在 FaceForensics、SkyTimelapse、Taichi-HD 和 UCF101 (256x256) 上训练的预训练检查点文件可以从 [huggingface](https://huggingface.co/maxin-cn/Latte/tree/main) 下载。 + +下载 `{}.pt` 文件后,请将其放置在 `models/` 文件夹下,然后运行 `tools/latte_converter.py`。例如,要转换 `models/skytimelapse.pt`,您可以运行: +```bash +python tools/latte_converter.py --source models/skytimelapse.pt --target models/skytimelapse.ckpt +``` + +请同时从 [huggingface/stabilityai.co](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main) 下载 VAE 检查点,并通过运行以下命令进行转换: +```bash +python tools/vae_converter.py --source path/to/vae/ckpt --target models/sd-vae-ft-mse.ckpt +``` + +## 3. 采样 + +例如,要在 Ascend 设备上使用 `256x256` 图像大小运行 `skytimelapse.ckpt` 模型的推理,您可以使用: +```bash +python sample.py -c configs/inference/sky.yaml +``` + +实验中 256x256 图像尺寸在Ascend 910*推理速度总结在以下表格中: + +| 模型名称 | 环境 | 卡数 | 图像尺寸 | 方法 | 步数 | 检查点加载时间 | 编译时间 | 总采样时间 | +| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | +| latte | 910*-MS2.3.1 | 1 | 256x256 | ddpm | 250 | 19.72s | 101.26s | 537.21s | + +这里展示了一些生成结果的例子: + + + + + + + + + + + +
Example 1Example 2Example 3
+

+ 图 2. 从 torch 检查点转换的预训练模型生成的视频。 +

+ +## 4. 训练 + +### 4.1 使用视频训练 + +现在,我们支持在 Sky Timelapse 数据集上训练 Latte 模型,这是一个视频数据集,可以从 https://github.com/weixiong-ur/mdgan 下载。 + +解压缩下载的文件后,您将获得一个名为 `sky_train/` 的文件夹,其中包含所有训练视频帧。文件夹结构类似于: +``` +sky_train/ +├── video_name_0/ +| ├── frame_id_0.jpg +| ├── frame_id_0.jpg +| └── ... +├── video_name_1/ +└── ... +``` + +首先,编辑配置文件 `configs/training/data/sky_video.yaml`。将 `data_folder` 从 `""` 更改为 `sky_train/` 的绝对路径。 + +然后,您可以使用以下命令在 Ascend 设备上开始独立训练: +```bash +python train.py -c configs/training/sky_video.yaml +``` + +要在 GPU 设备上开始训练,只需在上述命令后添加 `--device_target GPU`。 + +默认训练配置是从零开始训练 Latte 模型。批量大小是 $5$,训练周期数是 $3000$,这大约对应于 900k 步。学习率是恒定值 $1e^{-4}$。模型在混合精度模式下训练。默认的 AMP 级别是 `O2`。更多细节请参见 `configs/training/sky_video.yaml`。 + +为了加速训练速度,我们默认在配置文件中使用 `dataset_sink_mode: True`。您也可以设置 `enable_flash_attention: True` 进一步加速训练速度。 + +训练完成后,检查点将保存在 `output_dir/ckpt/` 下。要使用检查点运行推理,请将 `configs/inference/sky.yaml` 中的 `checkpoint` 更改为检查点的路径,然后运行 `python sample.py -c configs/inference/sky.yaml`。 + +训练周期数设置得较大以确保收敛。您可以在准备好时随时终止训练。例如,我们使用了训练了 $1700$ 周期(大约 $500k$ 步)的检查点,并用它进行了推理。以下是一些生成的示例: + + + + + + + + + + + +
Example 1Example 2Example 3
+

+ 图 3. 训练了大约 1700 轮(500k 步)的 Latte 模型生成的视频。 +

+ +### 4.2 使用嵌入缓存训练 + +我们可以通过在运行训练脚本之前缓存数据集的嵌入来加速训练速度。这需要三个步骤: + +- **步骤 1**:将嵌入缓存到一个缓存文件夹。请参阅以下关于如何缓存嵌入的示例。这一步可能需要一些时间。 + +要为 Sky Timelapse 数据集缓存嵌入,请首先确保 `configs/training/sky_video.yaml` 中的 `data_path` 正确设置为名为 `sky_train/` 的文件夹。然后您可以开始使用以下命令保存嵌入: + +```bash +python tools/embedding_cache.py --config configs/training/sky_video.yaml --cache_folder path/to/cache/folder --cache_file_type numpy +``` + +您也可以将 `cache_file_type` 更改为 `mindrecord` 以 `.mindrecord` 文件形式保存嵌入。 + +通常,我们建议使用 `mindrecord` 文件类型,因为它得到 `MindDataset` 的支持,可以更好地加速数据加载。然而,Sky Timelapse 数据集有额外的长视频。使用 `mindrecord` 文件缓存嵌入增加了超过 MindRecord 写入器最大页面大小的风险。因此,我们建议使用 `numpy` 文件。 + +嵌入缓存过程可能需要一些时间,具体取决于视频数据集的大小。在此过程中可能会抛出一些异常。如果抛出了意外的异常,程序将停止,嵌入缓存写入器的状态将显示在屏幕上: + +```bash +Start Video Index: 0. # 要处理的视频索引的开始 +Saving Attempts: 0: save 120 videos, failed 0 videos. # 保存的视频文件数量 +``` + +在这种情况下,您可以从索引 $120$(索引从 0 开始)的视频恢复嵌入缓存。只需附加 `--resume_cache_index 120`,然后运行 `python tools/embedding_cache.py`。它将从第 $120$ 个视频开始缓存嵌入,并保存嵌入,而不会覆盖现有文件。 + +要查看更多用法,请使用 `python tools/embedding_cache.py -h`。 + +- **步骤 2**:将数据集配置文件的 `data_folder` 更改为当前缓存文件夹路径。 + +缓存嵌入后,编辑 `configs/training/data/sky_numpy_video.yaml`,并将 `data_folder` 更改为存储缓存嵌入的文件夹。 + +- **步骤 3**:运行训练脚本。 + +您可以使用以下命令开始使用 Sky TimeLapse 的缓存嵌入数据集进行训练: + +```bash +python train.py -c configs/training/sky_numpy_video.yaml +``` + +请注意,在 `sky_numpy_video.yaml` 中,我们使用了大量帧 $128$ 和较小的采样步长 $1$,这与 `sky_video.yaml`(num_frames=16 和 stride=3)中的设置不同。嵌入缓存使我们能够训练 Latte 生成更多帧,具有更大的帧率。 + +由于内存限制,我们将本地批量大小设置为 $1$,并使用梯度累积步数 $4$。训练周期数为 $1000$,学习率为 $2e^{-5}$。总训练步数约为 1000k。 + +如果出现 OOM(内存不足),请在 `configs/training/sky_numpy_video.yaml` 中设置 `enable_flash_attention: True`。它可以减少内存成本,也可以加速训练速度。 + +### 4.3 分布式训练 + +对于`mindspore>=2.3.0`,建议使用msrun启动使用ImageNet数据集格式的4卡分布式训练,使用命令如下: + +``` +msrun --worker_num=4 \ + --local_worker_num=4 \ + --bind_core=True \ + --log_dir=msrun_log \ + python train.py \ + -c path/to/configuration/file \ + --use_parallel True +``` + +以 4 卡分布式训练为例,您可以使用以下命令开始分布式训练: +```bash +mpirun -n 4 python train.py \ + -c path/to/configuration/file \ + --use_parallel True +``` +其中配置文件可以从 `configs/training/` 文件夹中的 `.yaml` 文件中选择。 + +如果您有 Ascend 设备的RankTable,可以参考 `scripts/run_distributed_sky_numpy_video.sh`,并使用以下命令开始 4 卡分布式训练: +```bash +bash scripts/run_distributed_sky_numpy_video.sh path/to/rank/table 0 4 +``` + +第一个数字 `0` 表示训练设备的起始索引,第二个数字 `4` 表示您要启动的分布式进程总数。 + +## 5. 评估 + +实验中 256x256 图像尺寸在Ascend 910*训练速度总结在以下表格中: + +| 模型名称 | 环境 | 卡数 | 图像尺寸 | 图构建编译 | 批量大小 | 重计算 | 下沉 | 缓存 | 步时间 | 训练图像/秒 | +| :--------: | :----: | :---: | :--------: | :-----------: | :------: | :----: | :--: | :---: | :-------: | :-----------: | +| latte | 910*-MS2.3.1 | 1 | 256x256 | 6~8 分钟 | 5x16 | 关 | 开 | 关 | 1.03秒 | 77.67 | +| latte | 910*-MS2.3.1 | 1 | 256x256 | 6~8 分钟 | 1x128 | 开 | 开 | 开 | 1.21秒 | 105.78 | +| latte | 910*-MS2.3.1 | 4 | 256x256 | 6~8 分钟 | 1x128 | 开 | 开 | 开 | 1.32秒 | 387.87 | +| latte | 910*-MS2.3.1 | 8 | 256x256 | 6~8 分钟 | 1x128 | 开 | 开 | 开 | 1.31秒 | 781.67 | + + +# 参考 + +[1] Xin Ma, Yaohui Wang, Gengyun Jia, Xinyuan Chen, Ziwei Liu, Yuan-Fang Li, Cunjian Chen, Yu Qiao: Latte: Latent Diffusion Transformer for Video Generation. CoRR abs/2401.03048 (2024) + +[2] W. Peebles and S. Xie, “Scalable diffusion models with transformers,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4195–4205, 2023 diff --git a/examples/latte/args_train.py b/examples/latte/args_train.py index b19f4cc1d7..74aeea291f 100644 --- a/examples/latte/args_train.py +++ b/examples/latte/args_train.py @@ -190,6 +190,16 @@ def parse_train_args(parser): ) parser.add_argument("--log_interval", type=int, default=1, help="log interval") + parser.add_argument( + "--jit_level", + default="O2", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) return parser diff --git a/examples/latte/requirements.txt b/examples/latte/requirements.txt index d3c27d71a7..22cd1cf99e 100644 --- a/examples/latte/requirements.txt +++ b/examples/latte/requirements.txt @@ -5,3 +5,6 @@ safetensors albumentations ftfy regex +torch +av +tqdm diff --git a/examples/latte/sample.py b/examples/latte/sample.py index 7987d7dce4..9e17240810 100644 --- a/examples/latte/sample.py +++ b/examples/latte/sample.py @@ -38,6 +38,20 @@ def init_env(args): device_target=args.device_target, device_id=device_id, ) + if args.mode == ms.GRAPH_MODE: + try: + if args.jit_level in ["O0", "O1", "O2"]: + ms.set_context(jit_config={"jit_level": args.jit_level}) + logger.info(f"set jit_level: {args.jit_level}.") + else: + logger.warning( + f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method" + ) + except Exception: + logger.warning( + "The current jit_level is not suitable because current MindSpore version does not match," + "please ensure the MindSpore version >= ms2.3.0." + ) if args.precision_mode is not None: ms.set_context(ascend_config={"precision_mode": args.precision_mode}) return device_id @@ -145,6 +159,16 @@ def parse_args(): help="Whether to use conv2d layer or dense (linear layer) as Patch Embedder.", ) parser.add_argument("--ddim_sampling", type=str2bool, default=True, help="Whether to use DDIM for sampling") + parser.add_argument( + "--jit_level", + default="O2", + type=str, + choices=["O0", "O1", "O2"], + help="Used to control the compilation optimization level. Supports [“O0”, “O1”, “O2”]." + "O0: Except for optimizations that may affect functionality, all other optimizations are turned off, adopt KernelByKernel execution mode." + "O1: Using commonly used optimizations and automatic operator fusion optimizations, adopt KernelByKernel execution mode." + "O2: Ultimate performance optimization, adopt Sink execution mode.", + ) default_args = parser.parse_args() abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) if default_args.config: diff --git a/examples/latte/train.py b/examples/latte/train.py index 1a37f20d3c..e2fb00f705 100644 --- a/examples/latte/train.py +++ b/examples/latte/train.py @@ -58,6 +58,20 @@ def main(args): max_device_memory=args.max_device_memory, ascend_config=None if args.precision_mode is None else {"precision_mode": args.precision_mode}, ) + if args.ms_mode == ms.GRAPH_MODE: + try: + if args.jit_level in ["O0", "O1", "O2"]: + ms.set_context(jit_config={"jit_level": args.jit_level}) + logger.info(f"set jit_level: {args.jit_level}.") + else: + logger.warning( + f"Unsupport jit_level: {args.jit_level}. The framework automatically selects the execution method" + ) + except Exception: + logger.warning( + "The current jit_level is not suitable because current MindSpore version does not match," + "please ensure the MindSpore version >= ms2.3_0615." + ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) # 2. model initiate and weight loading