diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md
index 182aca797f..58dbe232f5 100644
--- a/examples/moviegen/README.md
+++ b/examples/moviegen/README.md
@@ -1,55 +1,271 @@
-# Movie Gen Video based on MindSpore
+# Movie Gen based on MindSpore
-This project is built on the [Movie Gen](https://arxiv.org/abs/2410.13720) paper by Meta for video generation, personalization, and editing. We aim to explore an efficient implementation based on MindSpore and Ascend NPUs. See our [report](docs/report.md) for more details.
+This repository implements the [Movie Gen](https://arxiv.org/abs/2410.13720) model presented by Meta.
+
+Movie Gen is a family of foundation models that can natively generate high-fidelity images and videos
+while also possessing the abilities to edit and personalize the videos.
+
+We aim to explore an efficient implementation based on MindSpore and Ascend NPUs.
+See our [report](docs/report.md) for more details.
## ๐ Development Plan
-This project is in an early stage and under active development. We welcome the open-source community to contribute to this project!
+This project is in an early stage and under active development. We welcome the open-source community to contribute to
+this project!
- Temporal Autoencoder (TAE)
- - [x] Inference
- - [x] Training
-- MovieGenVideo-5B (T2I/V)
- - [x] Inference
- - [x] Training stage 1: T2I 256px
- - [x] Training stage 2: T2I/V 256px 256frames
- - [ ] Training stage 3: T2I/V 768px 256frames (under training)
- - [x] Web Demo (Gradio)
-- MovieGenVideo-30B (T2I/V)
- - [x] Inference
- - [ ] Mixed parallelism training (support DP+SP+CP+TP+MP+Zero3, under training)
-- Personalized-MovieGenVideo (PT2V)
- - [ ] Inference
- - [ ] Training
-- MovieGen-Edit
- - [ ] Inference
- - [ ] Training
+ - [x] Inference
+ - [x] Training
+- Movie Gen 5B (T2I/V)
+ - [x] Inference
+ - [x] Training stage 1: T2I 256px
+ - [x] Training stage 2: T2I/V 256px 256frames
+ - [ ] Training stage 3: T2I/V 768px 256frames (under verification)
+ - [x] Web Demo (Gradio)
+- Movie Gen 30B (T2I/V)
+ - [x] Inference
+ - [x] Mixed parallelism training (support Ulysses-SP + ZeRO-3)
+ - [x] Training stage 1: T2I 256px
+ - [x] Training stage 2: T2V 256px 256frames
+ - [ ] Training stage 3: T2I/V 768px 256frames
+- Training with Buckets
+ - [ ] Support variable resolutions and aspect ratios
+ - [ ] Support variable number of frames
+- Video Personalization (PT2V)
+ - [ ] Inference
+ - [ ] Training
+- Video Editing
+ - [ ] Inference
+ - [ ] Training
+- Video Super-Resolution
+ - [ ] Inference
+ - [ ] Training
+
+## Demo
+
+| 256x256x455 | 256x256x455 |
+|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
+| | |
+| Caption
The video showcases a person wearing a blue cap and a plaid shirt, sitting on the ground with a golden retriever dog. The person is seen engaging in an affectionate interaction with the dog, gently stroking its fur and at one point, caressing or scratching behind the dog's ears. Throughout the video, the dog remains relaxed and content, with its mouth slightly open as if panting or smiling. The setting is an outdoor grassy area with fallen leaves or twigs scattered on the ground, under warm lighting that creates a cozy, intimate atmosphere focused on the bonding moment between the person and their canine companion. | Caption
The video features a close-up view of a cat with striking blue eyes and a white furry face adorned with brown and black stripes on its head. Initially, the cat is seen looking directly at the camera with an attentive expression, held gently by a human hand around its neck area against a blurred indoor background with a brown surface. As the video progresses, the cat's gaze becomes more intense and focused, with its whiskers appearing more prominent and alert. The camera zooms in slightly, cropping out some of the surrounding area to bring the cat's face into closer view, maintaining the attentive and engaged demeanor of the feline throughout the sequence. |
+| | |
+| Caption
The video showcases a static image of a bouquet of white roses, with the roses in various stages of bloom. The petals of the roses are delicate and pristine white, contrasting with the soft pink hues visible in their centers. The arrangement is full and lush, with stems protruding outwards. Throughout the video, there are no significant changes in the composition or positioning of the roses, and the background remains consistently blurred, ensuring the floral arrangement remains the focal point. | Caption
The video showcases a majestic snow-capped mountain range against a cloudy sky, with the peaks covered in pristine white snow and jagged rocky outcrops protruding from the slopes. The mountains cast long shadows across the snow-covered terrain below. Initially, the sky is a vivid blue with wispy white clouds, but as the video progresses, the clouds become slightly more dispersed, revealing more of the blue sky. Throughout the video, the overall composition and grandeur of the mountain vistas remain consistent, maintaining the serene and awe-inspiring natural beauty of the landscape. |
+
+## Requirements
+
+
+
+| MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel |
+|:---------:|:-------------:|:-----------:|:-------------------:|
+| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
+
+
+
+1. Install
+ [CANN 8.0.RC2.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1)
+ and MindSpore according to the [official instructions](https://www.mindspore.cn/install).
+2. Install requirements
+ ```shell
+ pip install -r requirements.txt
+ ```
+
+## Model Weights
+
+
+TAE
+
+Download the TAE weights from
+[here](https://download.mindspore.cn/toolkits/mindone/moviegen/tae_ucf101pt_mixkitft-b3b2e364.ckpt) and save them in the
+`models/` directory.
+
-## Temporal Autoencoder (TAE)
+
+Text Encoders
+Downloading and conversion of the text encoders' weights to the `.safetensors` format can be done automatically by using
+the following commands:
-### Requirements
+```shell
+python tools/download_convert_st.py "google/byt5-small"
+python tools/download_convert_st.py "google/ul2"
+```
-| mindspore | ascend driver | firmware |cann toolkit/kernel |
-|:----------:|:--------------:|:-----------:|:------------------:|
-| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |
+If you face an SSL certificate verification error, you can add `--disable_ssl_verify` option.
-### Prepare weights
+
-We use SD3.5 VAE to initialize the spatial layers of TAE, considering they have the same number of latent channels, i.e. 16.
+## Inference
-1. Download SD3.5 VAE from [huggingface](https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae)
+### Generating Text Embeddings
-2. Inflate VAE checkpoint for TAE initialization by
+Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating
+text embeddings online. Therefore, you need to prepare them in advance by running the following command:
```shell
-python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt
+python scripts/inference_text_enc.py \
+--model_name google/ul2 \
+--prompts_file /path/to/prompts.csv \
+--output_path /path/to/output/directory \
+--model_max_length 512
```
-### Prepare datasets
+> [!NOTE]
+> We use the sequence length of 512 tokens for UL2, 256 for MetaCLIP, and 100 for ByT5.
+
+### Text-to-Image
+
+For more detailed instructions, please run `python scripts/inference.py --help`.
+
+```shell
+python scripts/inference.py \
+--config configs/inference/moviegen_t2i_256px.yaml \
+--model.name llama-5B \
+--model.pretrained_model_path /path/to/llama-5B.ckpt \
+--text_emb.ul2_dir /path/to/ul2_embeddings \
+--text_emb.metaclip_dir /path/to/metaclip_embeddings \
+--text_emb.byt5_dir /path/to/byt5_embeddings \
+--image_size 256 455 \
+--batch_size 2
+```
+
+### Text-to-Video
+
+```shell
+python scripts/inference.py \
+--config configs/inference/moviegen_t2i_256px.yaml \
+--model.name llama-5B \
+--model.pretrained_model_path /path/to/llama-5B.ckpt \
+--text_emb.ul2_dir /path/to/ul2_embeddings \
+--text_emb.metaclip_dir /path/to/metaclip_embeddings \
+--text_emb.byt5_dir /path/to/byt5_embeddings \
+--image_size 256 455 \
+--num_frames 32 \
+--batch_size 2 \
+--save_format mp4
+```
+
+### Gradio Demo
+
+To launch the web demo, follow these steps:
+
+1. Install Gradio:
+
+```shell
+pip install gradio
+```
+
+2. Run the demo script with the following configuration. The demo provides 80 pre-computed text prompts to choose from:
+
+```shell
+python scripts/gradio_demo.py \
+--config configs/inference/moviegen_t2i_256px.yaml \
+--model.name llama-5B \
+--model.pretrained_model_path /path/to/llama-5B.ckpt \
+--text_emb.ul2_dir /path/to/ul2-embedding.ckpt \
+--text_emb.metaclip_dir /path/to/metaclip-embedding.ckpt \
+--text_emb.byt5_dir /path/to/byt5-embedding.ckpt \
+--image_size 256 455
+--num_frames 32
+--save_format mp4
+```
+
+Note: Make sure to replace the `/path/to/` placeholders with your actual model and embedding paths.
+
+## Training
+
+Movie Gen is trained jointly on images and videos in 4 stages:
+
+1. Training on images at 256 px resolution.
+2. Joint training on images and videos at 256 px resolution.
+3. Joint training at 768 px resolution.
+4. Fine-tune the model on high quality videos.
+
+Images are treated as single frame videos, enabling the use of the same model to generate both images and videos.
+Compared to video data, paired image-text datasets are easier to scale with diverse concepts and styles,
+and thus joint modeling of image and video leads to better generalization.
+
+To train Movie Gen, run the following commands:
+
+```shell
+scripts/moviegen/stage1_train.sh # for stage 1 training
+scripts/moviegen/stage2_train.sh # for stage 2 training
+scripts/moviegen/stage3_train.sh # for stage 3 training (currently under verification)
+```
+
+### Dataset Preparation
+
+Paths to videos and their corresponding captions should be stored in a CSV file with two columns: `video` and `caption`.
+For example:
+
+```text
+video,caption
+video_folder/part01/vid001.mp4,a cartoon character is walking through
+video_folder/part01/vid002.mp4,a red and white ball with an angry look on its face
+```
+
+### Generating Text Embeddings
+
+Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating
+text embeddings online. Please refer to the [Generating Text Embeddings](#generating-text-embeddings) section under the
+Inference section for details.
+
+### Cache Video Embedding (Optional)
+
+If you have sufficient storage budget, you can cache the video embeddings to speed up training by using the following
+command:
+
+```shell
+python scripts/inference_tae.py \
+--tae.pretrained=/path/to/tae.ckpt \
+--tae.dtype=bf16 \
+--video_data.folder=/path/to/folder/with/videos/ \
+--output_path=/path/to/output/directory/ \
+--video_data.size=256 \
+--video_data.crop_size=[256,455]
+```
+
+### Performance
+
+Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in Graph mode.
+
+> [!NOTE]
+> We trained all the models using BF16 precision.
+
+| Model | Cards | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | ZeRO | Sequence Parallel | TAE Cache | Time (s/step) | Config |
+|:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:----:|:-----------------:|:---------:|:-------------:|:--------------------------------------------------------------:|
+| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | O1 | 6m | ON | 1 | 3 | 8 shards | Yes | 4.08 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
+| 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | No | No | Yes | 1.29 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) |
+| 5B | 8 | 2 (T2I/V) | Image: 1
Video: 1 | 256x455
256 frames | O1 | 6m | ON
(Every 2 blocks) | 5 | 2 | No | Yes | 5.09 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
+| 5B | 8 | 3 (T2I/V) | Image: 1
Video: 1 | 576x1024
256 frames | O1 | 7m 30s | ON | 5 | 2 | No | Yes | 88.5 | [stage3_t2iv_768px.yaml](configs/train/stage3_t2iv_768px.yaml) |
+| 1B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 2m 15s | ON | 1 | No | No | Yes | 0.53 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) |
+| 1B | 8 | 2 (T2I/V) | Image: 10
Video: 10 | 256x455
32 frames | O0 | 1m 55s | ON | 1 | No | No | Yes | 2.07 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
+
+### Validation During Training
+
+Validation can be enabled by either setting parameters in the `valid` field of the configuration file
+([example](configs/train/stage1_t2i_256px.yaml)) or by supplying the following arguments to `train.py`:
+
+```shell
+--valid.sampling_steps 10 \
+--valid.frequency 100 \
+--valid.dataset.csv_path /path/to/valid_dataset.csv \
+--valid.dataset.video_folder /path/to/videos \
+--valid.dataset.text_emb_folder.ul2 /path/to/ul2_embeddings \
+--valid.dataset.text_emb_folder.metaclip /path/to/metaclip_embeddings \
+--valid.dataset.text_emb_folder.byt5 /path/to/byt5_embeddings
+```
+
+## Evaluation
+
+Coming soon.
+
+## TAE Training & Evaluation
+
+### Dataset Preparation
+
+We need to prepare a csv annotation file listing the path to each input video related to the root folder, indicated by
+the `video_folder` argument. An example is
-We need to prepare a csv annotation file listing the path to each input video related to the root folder, indicated by the `video_folder` argument. An example is
```
video
dance/vid001.mp4
@@ -58,12 +274,13 @@ dance/vid003.mp4
...
```
-Taking UCF-101 for example, please download the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extract it to `datasets/UCF-101` folder.
-
+Taking UCF-101, for example, please download the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extract
+it to `datasets/UCF-101` folder.
### Training
-TAE is trained to optimize the reconstruction loss, perceptual loss, and the outlier penalty loss (OPL) proposed in the MovieGen paper.
+TAE is trained to optimize the reconstruction loss, perceptual loss, and the outlier penalty loss (OPL) proposed in the
+MovieGen paper.
To launch training, please run
@@ -72,80 +289,36 @@ python scripts/train_tae.py \
--config configs/tae/train/mixed_256x256x32.yaml \
--output_path /path/to/save_ckpt_and_log \
--csv_path /path/to/video_train.csv \
---video_folder /path/to/video_root_folder \
+--folder /path/to/video_root_folder \
```
-Unlike the paper, we found that OPL loss doesn't benefit the training outcome in our ablation study (w/ OPl PSNR is 31.17). Thus we disable OPL loss by default. You may enable it by appending `--use_outlier_penalty_loss True`
+Unlike the paper, we found that OPL loss doesn't benefit the training outcome in our ablation study (with OPL, PSNR is
+31.17). Thus, we disable OPL loss by default. You may enable it by appending `--use_outlier_penalty_loss True`.
For more details on the arguments, please run `python scripts/train_tae.py --help`
-
### Evaluation
To run video reconstruction with the trained TAE model and evaluate the PSNR and SSIM on the test set, please run
```shell
-python scripts/inference_tae.py \
+python scripts/eval_tae.py \
--ckpt_path /path/to/tae.ckpt \
--batch_size 2 \
--num_frames 32 \
--image_size 256 \
--csv_path /path/to/video_test.csv \
---video_folder /path/to/video_root_folder \
+--folder /path/to/video_root_folder \
```
The reconstructed videos will be saved in `samples/recons`.
-#### Performance
+### Performance
Here, we report the training performance and evaluation results on the UCF-101 dataset.
Experiments are tested on ascend 910* with mindspore 2.3.1 graph mode.
-| model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe |
-| :--: | :---: | :--: | :--: | :--: | :--: | :--: |:--: | :--: |:--: |:--: |
-| TAE | 1 | 1 | 256x256x32 | bf16 | O0 | 2 mins | 2.18 | 31.35 | 0.92 | [config](configs/tae/train/mixed_256x256x32.yaml) |
-
-
-### Usages for Latent Diffusion Models
-
-
-View more
-
-#### Encoding video
-
-```python
-from mg.models.tae.tae import TemporalAutoencoder, TAE_CONFIG
-
-# may set use_tile=True to save memory
-tae = TemporalAutoencoder(
- pretrained='/path/to/tae.ckpt',
- use_tile=False,
- )
-
-# x - a batch of videos, shape (b c t h w)
-z, _, _ = tae.encode(x)
-
-
-# you may scale z by:
-# z = TAE_CONFIG['scaling_factor'] * (z - TAE_CONFIG['shift_factor'])
-
-```
-
-For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py)
-
-### Decoding video latent
-
-```python
-
-# if z is scaled, you should unscale at first:
-# z = z / TAE_CONFIG['scaling_factor'] + TAE_CONFIG['shift_factor']
-
-# z - a batch of video latent, shape (b c t h w)
-x = tae.decode(z)
-
-# for image decoding, set num_target_frames to discard the spurious frames
-x = tae.decode(z, num_target_frames=1)
-```
-
-
+| model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe |
+|:----------:|:-----:|:----------:|:----------:|:---------:|:---------:|:-------------:|:------:|:-----:|:----:|:-------------------------------------------------:|
+| TAE | 1 | 1 | 256x256x32 | bf16 | O0 | 2 min | 2.18 | 31.35 | 0.92 | [config](configs/tae/train/mixed_256x256x32.yaml) |
diff --git a/examples/moviegen/configs/inference/moviegen_t2i_256px.yaml b/examples/moviegen/configs/inference/moviegen_t2i_256px.yaml
new file mode 100644
index 0000000000..cb3a39b327
--- /dev/null
+++ b/examples/moviegen/configs/inference/moviegen_t2i_256px.yaml
@@ -0,0 +1,34 @@
+env:
+ mode: 0
+ jit_level: O0
+ seed: 42
+ distributed: False
+ debug: False
+
+model:
+ name: llama-5B
+ pretrained_model_path:
+ enable_flash_attention: True
+ dtype: bf16
+
+tae:
+ pretrained: models/tae_ucf101pt_mixkitft-b3b2e364.ckpt
+ use_tile: True
+ dtype: bf16
+
+# Inference parameters
+num_sampling_steps: 50
+sample_method: linear-quadratic
+image_size: [ 256, 256 ]
+num_frames: 1 # image
+text_emb:
+ ul2_dir:
+ metaclip_dir:
+ byt5_dir:
+batch_size: 10
+
+# Saving options
+output_path: ../../samples # the path is relative to this config
+append_timestamp: True
+save_format: png
+save_latent: False
diff --git a/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml
index af58e62ee1..dea2831835 100644
--- a/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml
+++ b/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml
@@ -1,5 +1,5 @@
# model
-pretrained_model_path: "models/tae_vae2d.ckpt"
+pretrained: "models/tae_vae2d.ckpt"
# loss
perceptual_loss_weight: 1.0
@@ -9,18 +9,16 @@ mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2
# data
-dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
-video_folder: "../videocomposer/datasets/webvid5"
-frame_stride: 1
-num_frames: 16
+folder: "../videocomposer/datasets/webvid5"
+sample_stride: 1
+sample_n_frames: 16
image_size: 256
crop_size: 256
# flip: True
# training recipe
seed: 42
-use_discriminator: False
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
@@ -29,7 +27,6 @@ scale_lr: False
weight_decay: 0.
dtype: "fp32"
-amp_level: "O0"
use_recompute: False
epochs: 2000
@@ -42,5 +39,5 @@ use_ema: False
output_path: "outputs/tae_train"
-# ms settting
+# ms setting
jit_level: O0
diff --git a/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml
index 990ec83d72..8c65f8314f 100644
--- a/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml
+++ b/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml
@@ -1,5 +1,5 @@
# model
-pretrained_model_path: "models/tae_vae2d.ckpt"
+pretrained: "models/tae_vae2d.ckpt"
# loss
perceptual_loss_weight: 1.0
@@ -9,18 +9,16 @@ mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2
# data
-dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
-video_folder: "../videocomposer/datasets/webvid5"
-frame_stride: 1
-num_frames: 32
+folder: "../videocomposer/datasets/webvid5"
+sample_stride: 1
+sample_n_frames: 32
image_size: 256
crop_size: 256
# flip: True
# training recipe
seed: 42
-use_discriminator: False
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
@@ -29,7 +27,6 @@ scale_lr: False
weight_decay: 0.
dtype: "bf16"
-amp_level: "O2" # reduce memory cost
use_recompute: True
epochs: 2000
@@ -42,5 +39,5 @@ use_ema: False
output_path: "outputs/tae_train"
-# ms settting
+# ms setting
jit_level: O0
diff --git a/examples/moviegen/configs/train/stage1_t2i_256px.yaml b/examples/moviegen/configs/train/stage1_t2i_256px.yaml
new file mode 100644
index 0000000000..c1ea4755a1
--- /dev/null
+++ b/examples/moviegen/configs/train/stage1_t2i_256px.yaml
@@ -0,0 +1,105 @@
+env:
+ mode: 0
+ jit_level: O0
+ seed: 42
+ distributed: False
+ debug: False
+
+model:
+ name: llama-5B
+ pretrained_model_path:
+ enable_flash_attention: True
+ recompute_every_nth_block: 1
+ dtype: bf16
+
+tae:
+ pretrained: models/tae_ucf101pt_mixkitft-b3b2e364.ckpt
+ use_tile: True
+ dtype: bf16
+
+dataset:
+ csv_path: CSV_PATH
+ video_folder: VIDEO_FOLDER
+ text_emb_folder:
+ ul2: UL2_FOLDER
+ byt5: BYT5_FOLDER
+ empty_text_emb:
+ ul2: EMPTY_TEXT_EMB
+ byt5: EMPTY_TEXT_EMB
+ text_drop_prob: 0.2
+ target_size: [ 256, 455 ]
+ apply_transforms_dataset: True
+ output_columns: [ "video", "ul2_caption", "byt5_caption" ]
+
+dataloader:
+ batch_size: 70
+ shuffle: True
+ num_workers_dataset: 4
+
+train:
+ steps: 30000
+ output_path: ../../output/stage1_t2i_256px # the path is relative to this config
+
+ sequence_parallel:
+ shards: 1
+
+ lr_scheduler:
+ name: constant
+ lr: 1.0e-4
+ warmup_steps: 1000
+
+ lr_reduce_on_plateau:
+ factor: 0.5
+ patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps
+ mode: min
+ min_delta: 0.01
+ min_lr: 1.0e-6
+
+ optimizer:
+ name: adamw_re
+ eps: 1e-15
+ betas: [ 0.9, 0.999 ]
+ weight_decay: 0.1
+
+ loss_scaler:
+ class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16
+ init_args:
+ loss_scale_value: 1
+
+ ema:
+ ema_decay: 0.9999
+ offloading: True
+
+ settings:
+ zero_stage: 0
+ gradient_accumulation_steps: 1
+ clip_grad: True
+ clip_norm: 1.0
+
+ save:
+ ckpt_save_policy: top_k
+ monitor_metric: eval_loss_smoothed
+ ckpt_save_interval: &save_interval 100
+ ckpt_max_keep: 10
+ log_interval: 1
+ save_ema_only: False
+ record_lr: False
+
+valid:
+ sampling_steps: 10
+ frequency: *save_interval # train.save.ckpt_save_interval should be divisible by the frequency
+
+ dataset:
+ csv_path: CSV_PATH
+ video_folder: VIDEO_FOLDER
+ text_emb_folder:
+ ul2: UL2_FOLDER
+ byt5: BYT5_FOLDER
+ target_size: [ 256, 256 ]
+ apply_transforms_dataset: True
+ output_columns: [ "video", "ul2_caption", "byt5_caption" ]
+
+ dataloader:
+ batch_size: 50
+ shuffle: False
+ num_workers_dataset: 4
diff --git a/examples/moviegen/configs/train/stage2_t2iv_256px.yaml b/examples/moviegen/configs/train/stage2_t2iv_256px.yaml
new file mode 100644
index 0000000000..5bc4019c36
--- /dev/null
+++ b/examples/moviegen/configs/train/stage2_t2iv_256px.yaml
@@ -0,0 +1,88 @@
+env:
+ mode: 0
+ jit_level: O0
+ seed: 42
+ distributed: False
+ debug: False
+
+model:
+ name: llama-5B
+ pretrained_model_path:
+ enable_flash_attention: True
+ recompute_every_nth_block: 1
+ dtype: bf16
+
+tae:
+ pretrained: models/tae_ucf101pt_mixkitft-b3b2e364.ckpt
+ use_tile: True
+ dtype: bf16
+
+dataset:
+ csv_path: CSV_PATH
+ video_folder: VIDEO_FOLDER
+ text_emb_folder:
+ ul2: UL2_FOLDER
+ byt5: BYT5_FOLDER
+ empty_text_emb:
+ ul2: EMPTY_TEXT_EMB
+ byt5: EMPTY_TEXT_EMB
+ text_drop_prob: 0.2
+ target_size: [ 256, 455 ]
+ sample_n_frames: 256 # FIXME: add variable frames support.
+ apply_transforms_dataset: True
+ output_columns: [ "video", "ul2_caption", "byt5_caption" ]
+
+dataloader:
+ batch_size:
+ image_batch_size: 1
+ video_batch_size: 1
+ shuffle: True
+ num_workers_dataset: 4
+
+train:
+ steps: 20000
+ output_path: ../../output/stage2_t2iv_256px # the path is relative to this config
+
+ sequence_parallel:
+ shards: 1
+
+ lr_scheduler:
+ name: constant
+ lr: 6.0e-5
+ warmup_steps: 1000
+
+ lr_reduce_on_plateau:
+ factor: 0.5
+ patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps
+ mode: min
+ min_delta: 0.01
+ min_lr: 1.0e-6
+
+ optimizer:
+ name: adamw_re
+ eps: 1e-15
+ betas: [ 0.9, 0.999 ]
+ weight_decay: 0.1
+
+ loss_scaler:
+ class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16
+ init_args:
+ loss_scale_value: 1
+
+ ema:
+ ema_decay: 0.9999
+ offloading: True
+
+ settings:
+ zero_stage: 0
+ gradient_accumulation_steps: 1
+ clip_grad: True
+ clip_norm: 1.0
+
+ save:
+ ckpt_save_policy: latest_k
+ ckpt_save_interval: &save_interval 100
+ ckpt_max_keep: 10
+ log_interval: 1
+ save_ema_only: False
+ record_lr: False
diff --git a/examples/moviegen/configs/train/stage3_t2iv_768px.yaml b/examples/moviegen/configs/train/stage3_t2iv_768px.yaml
new file mode 100644
index 0000000000..cec5e120a1
--- /dev/null
+++ b/examples/moviegen/configs/train/stage3_t2iv_768px.yaml
@@ -0,0 +1,88 @@
+env:
+ mode: 0
+ jit_level: O0
+ seed: 42
+ distributed: False
+ debug: False
+
+model:
+ name: llama-5B
+ pretrained_model_path:
+ enable_flash_attention: True
+ recompute_every_nth_block: 1
+ dtype: bf16
+
+tae:
+ pretrained: models/tae_ucf101pt_mixkitft-b3b2e364.ckpt
+ use_tile: True
+ dtype: bf16
+
+dataset:
+ csv_path: CSV_PATH
+ video_folder: VIDEO_FOLDER
+ text_emb_folder:
+ ul2: UL2_FOLDER
+ byt5: BYT5_FOLDER
+ empty_text_emb:
+ ul2: EMPTY_TEXT_EMB
+ byt5: EMPTY_TEXT_EMB
+ text_drop_prob: 0.2
+ target_size: [ 576, 1024 ]
+ sample_n_frames: 256 # FIXME: add variable frames support.
+ apply_transforms_dataset: True
+ output_columns: [ "video", "ul2_caption", "byt5_caption" ]
+
+dataloader:
+ batch_size:
+ image_batch_size: 1
+ video_batch_size: 1
+ shuffle: True
+ num_workers_dataset: 4
+
+train:
+ steps: 20000
+ output_path: ../../output/stage2_t2iv_256px # the path is relative to this config
+
+ sequence_parallel:
+ shards: 1
+
+ lr_scheduler:
+ name: constant
+ lr: 6.0e-5
+ warmup_steps: 1000
+
+ lr_reduce_on_plateau:
+ factor: 0.5
+ patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps
+ mode: min
+ min_delta: 0.01
+ min_lr: 1.0e-6
+
+ optimizer:
+ name: adamw_re
+ eps: 1e-15
+ betas: [ 0.9, 0.999 ]
+ weight_decay: 0.1
+
+ loss_scaler:
+ class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16
+ init_args:
+ loss_scale_value: 1
+
+ ema:
+ ema_decay: 0.9999
+ offloading: True
+
+ settings:
+ zero_stage: 0
+ gradient_accumulation_steps: 1
+ clip_grad: True
+ clip_norm: 1.0
+
+ save:
+ ckpt_save_policy: latest_k
+ ckpt_save_interval: &save_interval 100
+ ckpt_max_keep: 10
+ log_interval: 1
+ save_ema_only: False
+ record_lr: False
diff --git a/examples/moviegen/docs/report.md b/examples/moviegen/docs/report.md
index 548509b433..b5db4203c3 100644
--- a/examples/moviegen/docs/report.md
+++ b/examples/moviegen/docs/report.md
@@ -1,49 +1,61 @@
# MindSpore Movie Gen Report
-[Movie Gen](https://ai.meta.com/static-resource/movie-gen-research-paper) is a family of foundation models that can natively generate high-fidelity images, videos, and audio. Meta researchers found that scaling the training data, compute, and model parameters of the transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained with [Flow Matching](https://arxiv.org/abs/2210.02747) yields high-quality generative models for video or audio.
+[Movie Gen](https://ai.meta.com/static-resource/movie-gen-research-paper) is a family of foundation models that can
+natively generate high-fidelity images, videos, and audio. Meta researchers found that scaling the training data,
+compute, and model parameters of the transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained
+with [Flow Matching](https://arxiv.org/abs/2210.02747) yields high-quality generative models for video or audio.
-Movie Gen supports text-to-video/image generation (MovieGenVideo), video personalization (PersonalizedMovieGen), and video editing (MovieGenEdit).
+Movie Gen supports text-to-video/image generation (MovieGenVideo), video personalization (PersonalizedMovieGen), and
+video editing (MovieGenEdit).
-In this report, we will focus on MovieGenVideo and explore how to implement it with MindSpore, enabling model scaling and training efficiency.
+In this report, we will focus on MovieGenVideo and explore how to implement it with MindSpore, enabling model scaling
+and training efficiency.
At this moment, we support training MovieGenVideo with the following configuration.
-| model scale | image | 256px @256 | 768px @256 |
-| ---- | ----- | --- | --- |
-| 1B | โ
| โ
| โ
|
-| 5B | โ
| โ
| ๐ |
-| 30B | ๐ | ๐ | TODO |
-
-Here โ
means that training accuracy has been verified on a small-scale dataset, and ๐ means training is supported but the accuracy is under verfication.
-
+| model scale | image | 256px @256 | 768px @256 |
+|-------------|-------|------------|------------|
+| 1B | โ
| โ
| โ
|
+| 5B | โ
| โ
| ๐ |
+| 30B | ๐ | โ
| TODO |
+Here โ
means that training accuracy has been verified on a small-scale dataset, and ๐ means training is supported, but
+the accuracy is under verification.
## Temporal Autoencoder (TAE)
-TAE is used to encode the RGB pixel-space videos and images into a spatio-temporally compressed latent space. In particular, the input is compressed by 8x across each spatial dimension H and W, and the temporal dimension T. We follow the framework of Meta Movie Gen [[1](#references)] as below.
+TAE is used to encode the RGB pixel-space videos and images into a spatio-temporally compressed latent space. In
+particular, the input is compressed by 8x across each spatial dimension H and W, and the temporal dimension T. We follow
+the framework of Meta Movie Gen [[1](#references)] as below.
Figure 1. Video Encoding and Decoding using TAE
-TAE inflates an image autoencoder by adding 1-D temporal convolution in resnet blocks and attention blocks. Temporal compression is done by injecting temporal downsample and upsample layers.
-
+TAE inflates an image autoencoder by adding 1-D temporal convolution in resnet blocks and attention blocks. Temporal
+compression is done by injecting temporal downsample and upsample layers.
### Key design & implementation
-In this section, we explore the design and implementation details not illustrated in the Movie Gen paper. For example, how to perform padding and initialization for the Conv 2.5-D layers and how to configure the training frames.
+In this section, we explore the design and implementation details not illustrated in the Movie Gen paper. For example,
+how to perform padding and initialization for the Conv 2.5-D layers and how to configure the training frames.
#### SD3.5 VAE as the base image encoder
-In TAE, the number of channels of the latent space is 16 (C=16). It can help improve both the reconstruction and the generation performance compared to C=4 used in OpenSora or SDXL vae.
-
-We choose to use the [VAE]() in Stable Diffusion 3.5 as the image encoder to build TAE for it has the same number of latent channels and can generalize well in image generation.
+In TAE, the number of channels of the latent space is 16 (C=16). It can help improve both the reconstruction and the
+generation performance compared to C=4 used in OpenSora or SDXL vae.
+We choose to use the [VAE]() in Stable Diffusion 3.5 as the image encoder to build TAE for it has the same number of
+latent channels and can generalize well in image generation.
#### Conv2.5d implementation
-Firstly, we replace the Conv2d in VAE with Conv2.5d, which consists of a 2D spatial convolution followed by a 1D temporal convolution.
+Firstly, we replace the Conv2d in VAE with Conv2.5d, which consists of a 2D spatial convolution followed by a 1D
+temporal convolution.
-For 1D temporal convolution, we set kernel size 3, stride 1, symmetric replicate padding with padding size (1, 1), and input/output channels the same as spatial conv. We initialize the kernel weight so as to preserve the spatial features (i.e. preserve image encoding after temporal initialization). Therefore, we propose to use `centric` initialization as illustrated below.
+For 1D temporal convolution, we set kernel size 3, stride 1, symmetric replicate padding with padding size (1, 1), and
+input/output channels the same as spatial conv. We initialize the kernel weight to preserve the spatial features
+(i.e., preserve image encoding after temporal initialization). Therefore, we propose to use `centric` initialization as
+illustrated below.
```python
w = self.conv_temp.weight
@@ -53,43 +65,48 @@ for i in range(ch):
value[i, i, 0, 1] = 1
w.set_data(ms.Tensor(value, dtype=ms.float32))
```
-#### Temporal Downsampling
+#### Temporal Downsampling
Paper: "Temporal downsampling is performed via strided convolution with a stride of 2".
-Our implementation: the strided convolution is computed using conv1d of kernel size 3, stride 2, and symmetric replicate padding. `centric` initialization (as mentioned in the above conv2.5 section) is used to initialize the conv kernel weight.
+Our implementation: the strided convolution is computed using conv1d of kernel size 3, stride 2, and symmetric replicate
+padding. `centric` initialization (as mentioned in the above conv2.5 section) is used to initialize the conv kernel
+weight.
-To achieve 8x temporal compression, we apply 3 temporal downsampling layers, each placed after the spatial downsampling layer in the first 3 levels.
+To achieve 8x temporal compression, we apply 3 temporal downsampling layers, each placed after the spatial downsampling
+layer in the first 3 levels.
#### Temporal Upsampling
+
Paper: "upsampling by nearest-neighbor interpolation followed by convolution"
Our design:
-1. nearest-neighbour interpolation along the temporal dimension
-2. conv1d: kernel size 3, stride 1, symmetric replicate padding, and `centric` initialization.
-
-To achieve 8x temporal compression, we apply 3 temporal upsampling layers, each placed after the spatial upsampling layer of the last 3 levels.
+1. nearest-neighbour interpolation along the temporal dimension
+2. conv1d: kernel size 3, stride 1, symmetric replicate padding, and `centric` initialization.
+To achieve 8x temporal compression, we apply 3 temporal upsampling layers, each placed after the spatial upsampling
+layer of the last 3 levels.
### Evaluation
-We conduct experiments to verify our implementation's effectiveness on the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset containing 13,320 videos. We split the videos into training and test sets by 8:2.
+We conduct experiments to verify our implementation's effectiveness on
+the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset containing 13,320 videos. We split the videos into
+training and test sets by 8:2.
The training performance on MindSpore 2.3.1 and Ascend 910* and the accuracy on the test set are as follows.
-| model name | cards | batch size | resolution | precision | OPL Loss | s/step | PSNR | SSIM |
-| :--: | :---: | :--: | :--: | :--: | :--: |:--: | :--: |:--: |
-| TAE | 1 | 1 | 256x256x32 | bf16 | OFF | 2.18 | 31.35 | 0.92 |
-| TAE | 1 | 1 | 256x256x32 | bf16 | ON | 2.18 | 31.17 | 0.92 |
+| model name | cards | batch size | resolution | precision | OPL Loss | s/step | PSNR | SSIM |
+|:----------:|:-----:|:----------:|:----------:|:---------:|:--------:|:------:|:-----:|:----:|
+| TAE | 1 | 1 | 256x256x32 | bf16 | OFF | 2.18 | 31.35 | 0.92 |
+| TAE | 1 | 1 | 256x256x32 | bf16 | ON | 2.18 | 31.17 | 0.92 |
-
-The hyper-parameters we used are as follows.
+The hyperparameters we used are as follows.
```yaml
-kl loss weight: 1.0e-06
-perceptual and reconstruction loss weight: 1.0
+kl loss weight: 1.0e-06
+perceptual and reconstruction loss weight: 1.0
outlier penalty loss weight: 1.0
optimizer: adamw
learning rate: 1e-5
@@ -103,7 +120,8 @@ Here is the comparison between the origin videos (left) and the videos reconstru
-We further fine-tune the TAE model on the mixkit dataset, a high-quality video dataset in 1080P resolution. Here are the results.
+We further fine-tune the TAE model on the mixkit dataset, a high-quality video dataset in 1080P resolution. Here are the
+results.
@@ -116,7 +134,8 @@ The fine-tuned TAE is then used in MovieGenVideo transformer training as shown b
### Architecture
-MovieGenVideo uses the [LLaMa3](https://arxiv.org/abs/2407.21783) backbone architecture for the joint image-video generation
+MovieGenVideo uses the [LLaMa3](https://arxiv.org/abs/2407.21783) backbone architecture for the joint image-video
+generation
model, enabling confident scaling of the model size while maintaining efficient training, as shown in the figure below.
@@ -142,50 +161,31 @@ We have implemented the MovieGenVideo architecture in the following variations:
| 5B | 32 | 3072 | 8192 | 24 |
| 30B | 48 | 6144 | 16384 | 48 |
-
Detailed code implementation can be referred to:
[LLaMa3 Backbone](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/network.py#L273),
[Transformer Block](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/network.py#L52).
-
-### Mixed Parallelism
-
-Movie Gen employs multiple parallelism to achieve model scaling and training efficiency, including [fully sharded
-data parallelism](https://arxiv.org/abs/2304.11277)(FSDP), [tensor parallelism](https://arxiv.org/abs/1909.08053)(TP),
-[sequence parallelism](https://arxiv.org/abs/2105.13120)(SP), and context parallelism (CP).
-
-Currently, our implementation supports MovieGenVideo training with TP, SP, CP, and DP.
-
-- **Tensor-parallelism (TP)**
- shards the weights of linear layers either along columns or rows, and results in each NPU involved in the sharding
- performing _tp-size_ less work (FLOPs) and generating _tp-size_ fewer activations for column-parallel shards and
- consuming _tp-size_ fewer activations for row-parallel shards. The cost of performing such a sharding is the addition
- of all-reduce communication overheads in both the forward (row-parallel) and backward (column-parallel) passes.
-
- Our implementation can be referred to [TP](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/block.py#L59) and
- [FusedTP](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/block.py#L91)
-- **Sequence-parallelism (SP)**
- builds upon TP to also allow the sharding of the input over the sequence dimension for layers which are replicated and
- in which each sequence element can be treated independently. Such layers, e.g., LayerNorm, would otherwise perform
- duplicate compute and generate identical (and thus replicated) activations across the TP-group.
-
- Our implementation can be referred to [SP](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/network.py#L494)
-- **Context-parallelism (CP)**
-
- enables a partial sharding over the sequence dimension for the _sequence-dependent softmax-attention operation_. CP
- leverages the insight that for any given (_source_ (_context_), _target_ (_query_)) sequences pair, _softmax-attention
- is only sequence-dependent over the context and not the query_. Therefore, in the case of self-attention where the
- input source and target sequences are identical, CP allows the attention computation to be performed with only an
- all-gather for the $K$ and $V$ projections (instead of $Q$, $K$, and $V$) in the forward pass, and a reduce-scatter
- for their associated gradients in the backward.
-
- Our implementation can be referred to [CP Attention](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/block.py#L210) and
- [CP FlashAttention](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/examples/moviegen/mg/models/llama/block.py#L340)
-- **Fully sharded data parallel (FSDP)** shards the model, optimizer, and gradients across all data-parallel NPUs,
- synchronously gathering and scattering parameters and gradients throughout each training step.
-
- In our implementation, we use data parallelism with [Zero3](https://github.com/hadipash/mindone/blob/5aa1e4dc91d71934905319ba984704d4d4a62f8b/mindone/trainers/zero.py#L75) to serve the similar purpose of sharding the model parameters across multiple NPUs and optimize memory usage.
-
+### Sequence Parallelism
+
+The official [Movie Gen](https://ai.meta.com/research/publications/movie-gen-a-cast-of-media-foundation-models/) employs
+3D parallelism to enable model-level scaling across three dimensions: the number of parameters, input tokens, and
+dataset size, while also allowing horizontal scale-out to additional NPUs. It leverages a combination
+of [fully sharded data parallelism](https://arxiv.org/abs/2304.11277), [tensor parallelism](https://arxiv.org/abs/1909.08053), [sequence parallelism](https://arxiv.org/abs/2205.05198),
+and [context parallelism](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/context_parallel.html).
+
+Inspired by recent developments in long-sequence parallelism ([Ulysses-SP](https://arxiv.org/abs/2309.14509)
+and [USP](https://arxiv.org/abs/2405.07719)), we implement model parallelism
+using [Ulysses-SP](https://arxiv.org/abs/2309.14509) together with [ZeRO-3](https://arxiv.org/abs/1910.02054),instead of
+the approach used in Movie Gen. Ulysses-SP utilizes `All2ALL` communication for segments of the QKV tensors, drastically
+reducing communication costs compared to sequence parallelism implemented
+in [Megatron-LM](https://arxiv.org/abs/2405.07719), [DSP](https://arxiv.org/abs/2403.10266), as well as the sequence
+parallelism mentioned
+in [Movie Gen](https://ai.meta.com/research/publications/movie-gen-a-cast-of-media-foundation-models/). Alongside
+ZeRO-3, it achieves similar memory efficiency to [[Megatron-LM](https://arxiv.org/abs/2405.07719)]. Experimental results
+show that using Ulysses-SP + ZeRO-3, we can train a model of similar scale compared to 3D parallelism, with over 2x
+speed boost in training, corroborating the findings
+in [Megatron-LM](https://arxiv.org/abs/2405.07719), [Ulysses-SP](https://arxiv.org/abs/2309.14509),
+and [DSP](https://arxiv.org/abs/2403.10266).
### Training Details
@@ -197,19 +197,20 @@ Training is performed in multiple stages for better efficiency:
- Stage 1: Text-to-image pre-raining on 256 px images.
- Stage 2: T2I/V joint training on low-resolution images and videos of 256 px.
- Following the paper, we double the spatial [PE](#learnable-positional-embedding-pe) layers to accommodate
- various aspect ratios, add new temporal PE layers to support up to 32 latent frames and initialize spatial PE layers
- from the T2I model with 2x expansion.
+ Following the paper, we double the spatial [PE](#learnable-positional-embedding-pe) layers to accommodate
+ various aspect ratios, add new temporal PE layers to support up to 32 latent frames and initialize spatial PE layers
+ from the T2I model with 2x expansion.
- Stage 3: T2I/V joint training on high-resolution images and videos of 768 px.
- For this stage, we expand the spatial PE layers by 3x.
-
+ For this stage, we expand the spatial PE layers by 3x.
#### Training Objective
-Following the paper, we trained the transformer with [Flow Matching](https://arxiv.org/abs/2210.02747) with a simple linear interpolation scheme.
+Following the paper, we trained the transformer with [Flow Matching](https://arxiv.org/abs/2210.02747) with a simple
+linear interpolation scheme.
It is trained to predict the velocity $V_t = \frac{dX_t}{dt}$ which teaches it to 'move' the sample $X_t$
in the direction of the video sample $X_1$. The ground truth velocity is derived by:
-$$V_t = X_1 - (1-\sigma_{min})X_0$$. Note that this simple interpolation scheme naturally ensures zero terminal SNR at $t=0$.
+$$V_t = X_1 - (1-\sigma_{min})X_0$$. Note that this simple interpolation scheme naturally ensures zero terminal SNR
+at $t=0$.
#### Learning Rate Scheduling
@@ -225,7 +226,9 @@ and monitor the validation loss throughout training.
### Bucketization for variable duration and size (under verification)
-To support training with diverse video lengths and aspect ratios, we have integrated the data bucketing feature in [hpcai-opensora](https://github.com/mindspore-lab/mindone/tree/master/examples/opensora_hpcai#multi-resolution-training). This feature is under verification.
+To support training with diverse video lengths and aspect ratios, we have integrated the data bucketing feature
+in [hpcai-opensora](https://github.com/mindspore-lab/mindone/tree/master/examples/opensora_hpcai#multi-resolution-training).
+This feature is under verification.
### Inference Details
@@ -243,27 +246,28 @@ with 25 quadratically placed steps. The linear-quadratic strategy is predicated
inference steps are pivotal in setting up the scene and motion of the video since most changes occur in the first
solver steps.
-Our implementation can be referred to [here](https://github.com/hadipash/mindone/blob/movie_gen/examples/moviegen/mg/schedulers/rectified_flow.py#L55-L61)
+Our implementation can be referred
+to [here](https://github.com/hadipash/mindone/blob/movie_gen/examples/moviegen/mg/schedulers/rectified_flow.py#L55-L61)
[//]: # (TODO: fix the link above)
-
-
### Evaluation
-To verify the effectiveness of our design and implementation, we perform 3-stage training on a [mixkit](https://mixkit.co/) subset, consisting of 100 HQ videos up to 1080P.
+To verify the effectiveness of our design and implementation, we perform 3-stage training on
+a [mixkit](https://mixkit.co/) subset consisting of 100 HQ videos up to 1080P.
Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in graph mode.
-
-| model scale | cards | stage | batch size | resolution | recompute | TAE Cache | time (s/step) | recipe |
+| Model | Cards | Stage | Batch size | Resolution | Recompute | TAE Cache | Time (s/step) | Recipe |
|:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:-----------------------:|:---------:|:-------------:|:-----------------------------------------------------------------:|
-| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | ON | ON | 23.8 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
-| 5B | 8 | 1 (T2I) | 10 | 256x455 | ON | ON | 1.29 | [stage1_t2i_256px.yaml](../configs/train/stage1_t2i_256px.yaml) |
-| 5B | 8 | 2 (T2I/V) | Image: 1
Video: 1 | 256x455
256 frames | ON
(Every 2 blocks) | ON | 5.09 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
-| 5B | 8 | 3 (T2I/V) | Image: 1
Video: 1 | 576x1024
256 frames | ON | ON | 88.5 | [stage3_t2iv_768px.yaml](../configs/train/stage3_t2iv_768px.yaml) |
-| 1B | 8 | 1 (T2I) | 10 | 256x455 | ON | ON | 0.53 | [stage1_t2i_256px.yaml](../configs/train/stage1_t2i_256px.yaml) |
-| 1B | 8 | 2 (T2I/V) | Image: 10
Video: 10 | 256x455
32 frames | ON | ON | 2.07 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
+| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | ON | ON | 4.08 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
+| 5B | 8 | 1 (T2I) | 10 | 256x455 | ON | ON | 1.29 | [stage1_t2i_256px.yaml](../configs/train/stage1_t2i_256px.yaml) |
+| 5B | 8 | 2 (T2I/V) | Image: 1
Video: 1 | 256x455
256 frames | ON
(Every 2 blocks) | ON | 5.09 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
+| 5B | 8 | 3 (T2I/V) | Image: 1
Video: 1 | 576x1024
256 frames | ON | ON | 88.5 | [stage3_t2iv_768px.yaml](../configs/train/stage3_t2iv_768px.yaml) |
+| 1B | 8 | 1 (T2I) | 10 | 256x455 | ON | ON | 0.53 | [stage1_t2i_256px.yaml](../configs/train/stage1_t2i_256px.yaml) |
+| 1B | 8 | 2 (T2I/V) | Image: 10
Video: 10 | 256x455
32 frames | ON | ON | 2.07 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
+
+> [!NOTE]
> All the models are trained with BF16 precision.
#### Detailed Training Scripts
@@ -284,7 +288,7 @@ export GLOG_v=2
stage1_dir=output/stage1_t2i_256px/$(date +"%Y.%m.%d-%H.%M.%S")
msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$stage1_dir" \
-python train.py \
+python scripts/train.py \
--config configs/train/stage1_t2i_256px.yaml \
--env.mode 0 \
--env.jit_level O1 \
@@ -326,7 +330,7 @@ export GLOG_v=2
stage2_dir=output/stage2_t2iv_256px/$(date +"%Y.%m.%d-%H.%M.%S")
msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$stage2_dir" \
-python train.py \
+python scripts/train.py \
--config configs/train/stage2_t2iv_256px.yaml \
--env.mode 0 \
--env.jit_level O1 \
@@ -370,7 +374,7 @@ export GLOG_v=2
stage3_dir=output/stage3_t2iv_768px/$(date +"%Y.%m.%d-%H.%M.%S")
msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$stage3_dir" \
-python train.py \
+python scripts/train.py \
--config configs/train/stage3_t2iv_768px.yaml \
--env.mode 0 \
--env.jit_level O1 \
@@ -394,8 +398,6 @@ python train.py \
-
-
### Generated Video Examples
| 256x256x455 | 256x256x455 |
@@ -405,8 +407,8 @@ python train.py \
| | |
| Caption
The video showcases a static image of a bouquet of white roses, with the roses in various stages of bloom. The petals of the roses are delicate and pristine white, contrasting with the soft pink hues visible in their centers. The arrangement is full and lush, with stems protruding outwards. Throughout the video, there are no significant changes in the composition or positioning of the roses, and the background remains consistently blurred, ensuring the floral arrangement remains the focal point. | Caption
The video showcases a majestic snow-capped mountain range against a cloudy sky, with the peaks covered in pristine white snow and jagged rocky outcrops protruding from the slopes. The mountains cast long shadows across the snow-covered terrain below. Initially, the sky is a vivid blue with wispy white clouds, but as the video progresses, the clouds become slightly more dispersed, revealing more of the blue sky. Throughout the video, the overall composition and grandeur of the mountain vistas remain consistent, maintaining the serene and awe-inspiring natural beauty of the landscape. |
-
## References
+
[1] The Movie Gen team @ Meta. Movie Gen: A Cast of Media Foundation Models. 2024
diff --git a/examples/moviegen/mg/__init__.py b/examples/moviegen/mg/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/moviegen/mg/acceleration/__init__.py b/examples/moviegen/mg/acceleration/__init__.py
new file mode 100644
index 0000000000..51f2a4fd22
--- /dev/null
+++ b/examples/moviegen/mg/acceleration/__init__.py
@@ -0,0 +1,2 @@
+from .communications import *
+from .parallel_states import *
diff --git a/examples/moviegen/mg/acceleration/communications.py b/examples/moviegen/mg/acceleration/communications.py
new file mode 100644
index 0000000000..127a8c3b5e
--- /dev/null
+++ b/examples/moviegen/mg/acceleration/communications.py
@@ -0,0 +1,71 @@
+from typing import Callable, Literal, Tuple
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore.communication import GlobalComm, get_group_size, get_rank
+
+__all__ = ["SplitFowardGatherBackward", "GatherFowardSplitBackward"]
+
+
+def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor:
+ dim_size = x.shape[dim]
+ tensor_list = x.split(dim_size // world_size, axis=dim)
+ x = tensor_list[rank]
+ return x
+
+
+def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor:
+ x = x.swapaxes(0, dim)
+ x = func(x)
+ x = x.swapaxes(dim, 0)
+ return x
+
+
+class SplitFowardGatherBackward(nn.Cell):
+ def __init__(
+ self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.rank = get_rank(group)
+ self.world_size = get_group_size(group)
+ self.gather = ops.AllGather(group=group)
+
+ if grad_scale == "up":
+ self.scale = self.world_size
+ else:
+ self.scale = 1 / self.world_size
+
+ def construct(self, x: Tensor) -> Tensor:
+ return _split(x, self.dim, self.rank, self.world_size)
+
+ def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
+ dout = dout * self.scale
+ dout = _communicate_along_dim(dout, self.dim, self.gather)
+ return (dout,)
+
+
+class GatherFowardSplitBackward(nn.Cell):
+ def __init__(
+ self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.rank = get_rank(group)
+ self.world_size = get_group_size(group)
+ self.gather = ops.AllGather(group=group)
+
+ if grad_scale == "up":
+ self.scale = self.world_size
+ else:
+ self.scale = 1 / self.world_size
+
+ def construct(self, x: Tensor) -> Tensor:
+ x = _communicate_along_dim(x, self.dim, self.gather)
+ return x
+
+ def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
+ dout = dout * self.scale
+ dout = _split(dout, self.dim, self.rank, self.world_size)
+ return (dout,)
diff --git a/examples/moviegen/mg/acceleration/parallel_states.py b/examples/moviegen/mg/acceleration/parallel_states.py
new file mode 100644
index 0000000000..15391aac6b
--- /dev/null
+++ b/examples/moviegen/mg/acceleration/parallel_states.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from mindspore.communication import create_group, get_group_size, get_rank
+
+__all__ = ["set_sequence_parallel_group", "get_sequence_parallel_group", "create_parallel_group"]
+
+_GLOBAL_PARALLEL_GROUPS = dict()
+
+
+def set_sequence_parallel_group(group: str) -> None:
+ _GLOBAL_PARALLEL_GROUPS["sequence"] = group
+
+
+def get_sequence_parallel_group() -> Optional[str]:
+ return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)
+
+
+def create_parallel_group(shards: int) -> None:
+ if shards <= 1:
+ raise ValueError(
+ f"The number of sequence parallel shards must be larger than 1 to enable sequence parallel, but got {shards}."
+ )
+
+ device_num = get_group_size()
+ if device_num % shards != 0:
+ raise ValueError(
+ f"Total number of devices ({device_num}) must be divisible by the number of sequence parallel shards ({shards})."
+ )
+
+ rank_id = get_rank()
+ sp_group_id = rank_id // shards
+ sp_group_rank_ids = list(range(sp_group_id * shards, (sp_group_id + 1) * shards))
+ sp_group_name = f"sp_group_{sp_group_id}"
+ create_group(sp_group_name, sp_group_rank_ids)
+ set_sequence_parallel_group(sp_group_name)
diff --git a/examples/moviegen/mg/dataset/__init__.py b/examples/moviegen/mg/dataset/__init__.py
new file mode 100644
index 0000000000..d968ae874a
--- /dev/null
+++ b/examples/moviegen/mg/dataset/__init__.py
@@ -0,0 +1,2 @@
+from .buckets import bucket_split_function
+from .dataset import ImageVideoDataset
diff --git a/examples/moviegen/mg/dataset/buckets.py b/examples/moviegen/mg/dataset/buckets.py
new file mode 100644
index 0000000000..e8d4970f36
--- /dev/null
+++ b/examples/moviegen/mg/dataset/buckets.py
@@ -0,0 +1,13 @@
+from typing import Callable, List, Tuple
+
+import numpy as np
+
+
+def bucket_split_function(
+ image_batch_size: int, video_batch_size: int
+) -> Tuple[Callable[[np.ndarray], int], List[int], List[int]]:
+ return (
+ lambda x: int(x.shape[0] > 1), # image or video
+ [1], # 2 buckets for now: image and videos of fixed length
+ [image_batch_size, video_batch_size],
+ )
diff --git a/examples/moviegen/mg/dataset/dataset.py b/examples/moviegen/mg/dataset/dataset.py
new file mode 100644
index 0000000000..9019aff836
--- /dev/null
+++ b/examples/moviegen/mg/dataset/dataset.py
@@ -0,0 +1,263 @@
+import csv
+import logging
+import os
+import random
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from mindone.data import BaseDataset
+from mindone.data.video_reader import VideoReader
+
+from .transforms import ResizeCrop
+
+_logger = logging.getLogger(__name__)
+
+
+IMAGE_EXT = (".jpg", ".jpeg", ".png", ".gif", ".webp")
+
+
+class ImageVideoDataset(BaseDataset):
+ def __init__(
+ self,
+ csv_path: str,
+ video_folder: str,
+ text_emb_folder: Optional[Union[str, Dict[str, str]]] = None,
+ empty_text_emb: Optional[Union[str, Dict[str, str]]] = None,
+ text_drop_prob: float = 0.2,
+ tae_latent_folder: Optional[str] = None,
+ tae_scale_factor: float = 1.5305,
+ tae_shift_factor: float = 0.0609,
+ target_size: Optional[Tuple[int, int]] = None,
+ sample_n_frames: int = 17,
+ sample_stride: int = 1,
+ frames_mask_generator: Optional[Callable[[int], np.ndarray]] = None,
+ t_compress_func: Optional[Callable[[int], int]] = None,
+ filter_data: bool = False,
+ apply_transforms_dataset: bool = False,
+ *,
+ output_columns: List[str],
+ ):
+ if text_emb_folder is None:
+ raise NotImplementedError(
+ "Text embedding during training is not supported, please provide `text_emb_folder`."
+ )
+
+ self._data = self._read_data(video_folder, csv_path, text_emb_folder, tae_latent_folder, filter_data)
+ self._frames = sample_n_frames
+ self._stride = sample_stride
+ self._min_length = (self._frames - 1) * self._stride + 1
+
+ self._text_emb_folder = text_emb_folder
+ self._empty_text_emb = empty_text_emb if text_drop_prob > 0 else None
+ if self._empty_text_emb:
+ if isinstance(self._empty_text_emb, str):
+ assert os.path.exists(self._empty_text_emb), f"Empty text embedding not found: {self._empty_text_emb}"
+ else:
+ for path in self._empty_text_emb.values():
+ assert os.path.exists(path), f"Empty text embedding not found: {path}"
+ self._text_drop_prob = text_drop_prob
+
+ self._tae_latent_folder = tae_latent_folder
+ self._tae_scale_factor = tae_scale_factor
+ self._tae_shift_factor = tae_shift_factor
+ self._fmask_gen = frames_mask_generator
+ self._t_compress_func = t_compress_func or (lambda x: x)
+
+ self.output_columns = output_columns
+
+ self._transforms = (
+ self.train_transforms(target_size, interpolation=cv2.INTER_AREA) if apply_transforms_dataset else None
+ )
+
+ # prepare replacement data in case the loading of a sample fails
+ self._prev_ok_sample = self._get_replacement()
+ self._require_update_prev = False
+
+ @staticmethod
+ def _read_data(
+ data_dir: str,
+ csv_path: str,
+ text_emb_folder: Optional[Union[str, Dict[str, str]]] = None,
+ tae_latent_folder: Optional[str] = None,
+ filter_data: bool = False,
+ ) -> List[dict]:
+ def _filter_data(sample_):
+ if not os.path.isfile(sample_["video"]):
+ _logger.warning(f"Video not found: {sample_['video']}")
+ return None
+ if "text_emb" in sample_:
+ if isinstance(sample_["text_emb"], str) and not os.path.isfile(sample_["text_emb"]):
+ _logger.warning(f"Text embedding not found: {sample_['text_emb']}")
+ return None
+ else:
+ for name, path in sample_["text_emb"].items():
+ if not os.path.isfile(sample_["text_emb"][name]):
+ _logger.warning(f"Text embedding not found: {sample_['text_emb'][name]}")
+ return None
+ if "tae_latent" in sample_ and not os.path.isfile(sample_["tae_latent"]):
+ _logger.warning(f"Text embedding not found: {sample_['tae_latent']}")
+ return None
+ return sample_
+
+ with open(csv_path, "r", encoding="utf-8") as csv_file:
+ try:
+ data = []
+ for item in csv.DictReader(csv_file):
+ sample = {**item, "video": os.path.join(data_dir, item["video"])}
+ if text_emb_folder:
+ if isinstance(text_emb_folder, str):
+ sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz"))
+ else:
+ sample["text_emb"] = {
+ name: os.path.join(path, Path(item["video"]).with_suffix(".npz"))
+ for name, path in text_emb_folder.items()
+ }
+ if tae_latent_folder:
+ sample["tae_latent"] = os.path.join(tae_latent_folder, Path(item["video"]).with_suffix(".npz"))
+ data.append(sample)
+ except KeyError as e:
+ _logger.error(f"CSV file requires `video` (file paths) column, but got {list(item.keys())}")
+ raise e
+
+ if filter_data:
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ data = [
+ item
+ for item in tqdm(executor.map(_filter_data, data), total=len(data), desc="Filtering data")
+ if item is not None
+ ]
+
+ _logger.info(f"Number of data samples: {len(data)}")
+ return data
+
+ def _get_replacement(self, max_attempts: int = 100) -> Tuple[np.ndarray, ...]:
+ attempts, error = min(max_attempts, len(self)), None
+ for idx in range(attempts):
+ try:
+ return self._get_item(idx)
+ except Exception as e:
+ error = e
+ _logger.debug(f"Failed to load a replacement sample: {repr(e)}")
+
+ raise RuntimeError(f"Fail to load a replacement sample in {attempts} attempts. Error: {repr(error)}")
+
+ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tuple[np.ndarray, ...]:
+ data = self._data[idx].copy()
+ num_frames = self._frames
+
+ if self._text_emb_folder:
+ if self._empty_text_emb and random.random() <= self._text_drop_prob:
+ data["text_emb"] = self._empty_text_emb
+
+ if isinstance(data["text_emb"], str):
+ with np.load(data["text_emb"]) as td:
+ data.update({"caption": td["text_emb"], "mask": td["mask"]})
+ else:
+ for enc_name, path in data["text_emb"].items():
+ with np.load(path) as td:
+ data.update({enc_name + "_caption": td["text_emb"], enc_name + "_mask": td["mask"]})
+
+ if self._tae_latent_folder:
+ tae_latent_data = np.load(data["tae_latent"])
+ latent_mean, latent_std = tae_latent_data["latent_mean"], tae_latent_data["latent_std"] # C T H W
+ if latent_mean.shape[1] < self._min_length: # TODO: add support for images and buckets
+ raise ValueError(f"Video is too short: {data['video']}")
+
+ start_pos = random.randint(0, len(latent_mean) - self._min_length)
+ batch_index = np.linspace(start_pos, start_pos + self._min_length - 1, num_frames, dtype=int)
+
+ latent_mean, latent_std = latent_mean[batch_index], latent_std[batch_index]
+ tae_latent = np.random.normal(latent_mean, latent_std).astype(np.float32)
+ tae_latent = (tae_latent - self._tae_shift_factor) * self._tae_scale_factor
+ # FIXME: remove unnecessary transpose
+ data["video"] = np.transpose(tae_latent, (1, 0, 2, 3)) # C T H W -> T C H W
+
+ else:
+ if data["video"].lower().endswith(IMAGE_EXT):
+ num_frames = 1
+ data["fps"] = np.array(120, dtype=np.float32) # FIXME: extract as IMG_FPS
+ data["video"] = cv2.cvtColor(cv2.imread(data["video"]), cv2.COLOR_BGR2RGB)
+ else:
+ with VideoReader(data["video"]) as reader:
+ min_length = self._min_length
+ if thw is not None:
+ num_frames, *data["size"] = thw
+ min_length = (num_frames - 1) * self._stride + 1
+ if len(reader) < min_length:
+ raise ValueError(f"Video is too short: {data['video']}")
+ start_pos = random.randint(0, len(reader) - min_length)
+ data["video"] = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride)
+ data["fps"] = np.array(reader.fps / self._stride, dtype=np.float32)
+
+ data["num_frames"] = np.array(num_frames, dtype=np.float32)
+
+ if self._fmask_gen is not None:
+ # return frames mask with respect to the TAE's latent temporal compression
+ data["frames_mask"] = self._fmask_gen(self._t_compress_func(num_frames))
+
+ if self._transforms:
+ data = self._apply_transforms(data)
+
+ return tuple(data[c] for c in self.output_columns)
+
+ def get_bucket(self, thw: Tuple[int, int, int], sample_ids: List[int]) -> Tuple[np.ndarray, ...]:
+ batch = [self._get_item(sample_id, thw) for sample_id in sample_ids]
+ return tuple(np.stack(item) for item in map(list, zip(*batch)))
+
+ def __getitem__(self, idx: int) -> Tuple[np.ndarray, ...]:
+ try:
+ sample = self._get_item(idx)
+ if self._require_update_prev:
+ self._prev_ok_sample = sample
+ self._require_update_prev = False
+ except Exception as e:
+ _logger.warning(f"Failed to fetch sample #{idx}, the video will be replaced. Error: {e}")
+ sample = self._prev_ok_sample
+ self._require_update_prev = True
+
+ return sample
+
+ def __len__(self):
+ return len(self._data)
+
+ def _apply_transforms(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ for transform in self._transforms:
+ input_data = tuple(data[column] for column in transform["input_columns"])
+ for op in transform["operations"]:
+ input_data = op(*input_data)
+ if not isinstance(input_data, tuple): # wrap numpy array in a tuple
+ input_data = (input_data,)
+ data.update(zip(transform.get("output_columns", transform["input_columns"]), input_data))
+ return data
+
+ def train_transforms(
+ self,
+ target_size: Tuple[int, int],
+ interpolation: int = cv2.INTER_LINEAR,
+ tokenizer: Optional[Callable[[str], np.ndarray]] = None,
+ ) -> List[dict]:
+ transforms = []
+ if not self._tae_latent_folder:
+ transforms.append(
+ {
+ "operations": [
+ ResizeCrop(target_size, interpolation=interpolation),
+ lambda x: x.astype(np.float32) / 127.5 - 1,
+ lambda x: x[None, ...] if x.ndim == 3 else x, # if image
+ lambda x: np.transpose(x, (0, 3, 1, 2)),
+ ],
+ "input_columns": ["video"],
+ }
+ )
+
+ if "caption" in self.output_columns and not self._text_emb_folder:
+ if tokenizer is None:
+ raise RuntimeError("Please provide a tokenizer for text data in `train_transforms()`.")
+ transforms.append({"operations": [tokenizer], "input_columns": ["caption"]})
+
+ return transforms
diff --git a/examples/moviegen/mg/datasets/tae_dataset.py b/examples/moviegen/mg/dataset/tae_dataset.py
similarity index 56%
rename from examples/moviegen/mg/datasets/tae_dataset.py
rename to examples/moviegen/mg/dataset/tae_dataset.py
index e09310460a..7be42dbd48 100644
--- a/examples/moviegen/mg/datasets/tae_dataset.py
+++ b/examples/moviegen/mg/dataset/tae_dataset.py
@@ -1,17 +1,19 @@
import copy
import csv
-import glob
import logging
import os
import random
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, Union
-import albumentations
import cv2
import imageio
import numpy as np
from decord import VideoReader
-import mindspore as ms
+from mindone.data import BaseDataset
+
+__all__ = ["VideoDataset", "BatchTransform"]
logger = logging.getLogger()
@@ -20,21 +22,25 @@ def create_video_transforms(
size=384, crop_size=256, interpolation="bicubic", backend="al", random_crop=False, flip=False, num_frames=None
):
if backend == "al":
+ os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" # prevent albumentations from being annoying
# expect rgb image in range 0-255, shape (h w c)
- from albumentations import CenterCrop, HorizontalFlip, RandomCrop, SmallestMaxSize
+ from albumentations import CenterCrop, Compose, HorizontalFlip, RandomCrop, SmallestMaxSize
+
+ if isinstance(crop_size, int):
+ crop_size = (crop_size, crop_size)
# NOTE: to ensure augment all frames in a video in the same way.
assert num_frames is not None, "num_frames must be parsed"
- targets = {"image{}".format(i): "image" for i in range(num_frames)}
+ targets = {f"image{i}": "image" for i in range(num_frames)}
mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC}
transforms = [
SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]),
- CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size),
+ CenterCrop(*crop_size) if not random_crop else RandomCrop(*crop_size),
]
if flip:
transforms += [HorizontalFlip(p=0.5)]
- pixel_transforms = albumentations.Compose(
+ pixel_transforms = Compose(
transforms,
additional_targets=targets,
)
@@ -44,29 +50,40 @@ def create_video_transforms(
return pixel_transforms
-def get_video_path_list(folder):
- # TODO: find recursively
- fmts = ["avi", "mp4", "gif"]
- out = []
- for fmt in fmts:
- out += glob.glob(os.path.join(folder, f"*.{fmt}"))
- return sorted(out)
+def get_video_path_list(folder: str, video_column: str) -> List[Dict[str, str]]:
+ """
+ Constructs a list of images and videos in the given directory (recursively).
+
+ Args:
+ folder: path to a directory containing images and videos.
+ video_column: name of the column to store video paths.
+ Returns:
+ A list of paths to images and videos in the given directory (absolute and relative).
+ """
+ exts = (".jpg", ".jpeg", ".png", ".gif", ".mp4", ".avi")
+ data = [
+ {video_column: str(item), "rel_path": str(item.relative_to(folder))}
+ for item in Path(folder).rglob("*")
+ if (item.is_file() and item.suffix.lower() in exts)
+ ]
+ return sorted(data, key=lambda x: x[video_column])
-class VideoDataset:
+class VideoDataset(BaseDataset):
def __init__(
self,
- csv_path=None,
- data_folder=None,
- size=384,
- crop_size=256,
- random_crop=False,
- flip=False,
- sample_stride=4,
- sample_n_frames=16,
- return_image=False,
- transform_backend="al",
- video_column="video",
+ csv_path: Optional[str],
+ folder: str,
+ size: int = 384,
+ crop_size: Union[int, Tuple[int, int]] = 256,
+ random_crop: bool = False,
+ flip: bool = False,
+ sample_stride: int = 1,
+ sample_n_frames: int = 16,
+ return_image: bool = False,
+ video_column: str = "video",
+ *,
+ output_columns: List[str],
):
"""
size: image resize size
@@ -76,17 +93,18 @@ def __init__(
if csv_path is not None:
with open(csv_path, "r") as csvfile:
- self.dataset = list(csv.DictReader(csvfile))
- self.read_from_csv = True
+ self.dataset = [
+ {**item, video_column: os.path.join(folder, item[video_column]), "rel_path": item[video_column]}
+ for item in csv.DictReader(csvfile)
+ ]
else:
- self.dataset = get_video_path_list(data_folder)
- self.read_from_csv = False
+ self.dataset = get_video_path_list(folder, video_column)
self.length = len(self.dataset)
logger.info(f"Num data samples: {self.length}")
logger.info(f"sample_n_frames: {sample_n_frames}")
- self.data_folder = data_folder
+ self.folder = folder
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.return_image = return_image
@@ -98,8 +116,8 @@ def __init__(
flip=flip,
num_frames=sample_n_frames,
)
- self.transform_backend = transform_backend
self.video_column = video_column
+ self.output_columns = output_columns
# prepare replacement data
max_attempts = 100
@@ -123,12 +141,8 @@ def get_replace_data(self, max_attempts=100):
def get_batch(self, idx):
# get video raw pixels (batch of frame) and its caption
- if self.read_from_csv:
- video_dict = self.dataset[idx]
- video_fn = video_dict[list(video_dict.keys())[0]]
- video_path = os.path.join(self.data_folder, video_fn)
- else:
- video_path = self.dataset[idx]
+ video_dict = self.dataset[idx].copy()
+ video_path = video_dict[self.video_column]
video_reader = VideoReader(video_path)
@@ -142,13 +156,13 @@ def get_batch(self, idx):
batch_index = [random.randint(0, video_length - 1)]
if video_path.endswith(".gif"):
- pixel_values = video_reader[batch_index] # shape: (f, h, w, c)
+ video_dict[self.video_column] = video_reader[batch_index] # shape: (f, h, w, c)
else:
- pixel_values = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c)
+ video_dict[self.video_column] = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c)
del video_reader
- return pixel_values
+ return tuple(video_dict[c] for c in self.output_columns)
def __len__(self):
return self.length
@@ -159,44 +173,47 @@ def __getitem__(self, idx):
video: preprocessed video frames in shape (f, c, h, w), normalized to [-1, 1]
"""
try:
- pixel_values = self.get_batch(idx)
+ data = self.get_batch(idx)
if (self.prev_ok_sample is None) or (self.require_update_prev):
- self.prev_ok_sample = copy.deepcopy(pixel_values)
+ self.prev_ok_sample = copy.deepcopy(data)
self.require_update_prev = False
except Exception as e:
logger.warning(f"Fail to get sample of idx {idx}. The corrupted video will be replaced.")
print("\tError msg: {}".format(e), flush=True)
assert self.prev_ok_sample is not None
- pixel_values = self.prev_ok_sample # unless the first sample is already not ok
+ data = self.prev_ok_sample # unless the first sample is already not ok
self.require_update_prev = True
if idx >= self.length:
raise IndexError # needed for checking the end of dataset iteration
+ pixel_values = data[0]
num_frames = len(pixel_values)
# pixel value: (f, h, w, 3) -> transforms -> (f 3 h' w')
- if self.transform_backend == "al":
- # NOTE:it's to ensure augment all frames in a video in the same way.
- # ref: https://albumentations.ai/docs/examples/example_multi_target/
+ # NOTE:it's to ensure augment all frames in a video in the same way.
+ # ref: https://albumentations.ai/docs/examples/example_multi_target/
- inputs = {"image": pixel_values[0]}
- for i in range(num_frames - 1):
- inputs[f"image{i}"] = pixel_values[i + 1]
+ inputs = {"image": pixel_values[0]}
+ for i in range(num_frames - 1):
+ inputs[f"image{i}"] = pixel_values[i + 1]
- output = self.pixel_transforms(**inputs)
+ output = self.pixel_transforms(**inputs)
- pixel_values = np.stack(list(output.values()), axis=0)
- # (t h w c) -> (c t h w)
- pixel_values = np.transpose(pixel_values, (3, 0, 1, 2))
- else:
- raise NotImplementedError
+ pixel_values = np.stack(list(output.values()), axis=0)
+ # (t h w c) -> (c t h w)
+ pixel_values = np.transpose(pixel_values, (3, 0, 1, 2))
if self.return_image:
pixel_values = pixel_values[1]
pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32)
- return pixel_values
+ return pixel_values, *data[1:]
+
+ @staticmethod
+ def train_transforms(**kwargs) -> List[dict]:
+ # train transforms are performed during data reading
+ pass
# TODO: parse in config dict
@@ -214,87 +231,48 @@ def check_sanity(x, save_fp="./tmp.gif"):
class BatchTransform:
- def __init__(self, mixed_strategy, mixed_image_ratio=0.2):
- self.mixed_strategy = mixed_strategy
+ def __init__(
+ self,
+ mixed_strategy: Literal["mixed_video_image", "mixed_video_random", "image_only"],
+ mixed_image_ratio: float = 0.2,
+ ):
+ if mixed_strategy == "mixed_video_image":
+ self._trans_fn = self._mixed_video_image
+ elif mixed_strategy == "mixed_video_random":
+ self._trans_fn = self._mixed_video_random
+ elif mixed_strategy == "image_only":
+ self._trans_fn = self._image_only
+ else:
+ raise NotImplementedError(f"Unknown mixed_strategy: {mixed_strategy}")
self.mixed_image_ratio = mixed_image_ratio
- def __call__(self, x):
- # x: (bs, c, t, h, w)
- if self.mixed_strategy == "mixed_video_image":
- if random.random() < self.mixed_image_ratio:
- x = x[:, :, :1, :, :]
- elif self.mixed_strategy == "mixed_video_random":
- # TODO: somehow it's slow. consider do it with tensor in NetWithLoss
- length = random.randint(1, x.shape[2])
- x = x[:, :, :length, :, :]
- elif self.mixed_strategy == "image_only":
+ def _mixed_video_image(self, x: np.ndarray) -> np.ndarray:
+ if random.random() < self.mixed_image_ratio:
x = x[:, :, :1, :, :]
- else:
- raise ValueError
return x
+ @staticmethod
+ def _mixed_video_random(x: np.ndarray) -> np.ndarray:
+ # TODO: somehow it's slow. consider do it with tensor in NetWithLoss
+ length = random.randint(1, x.shape[2])
+ return x[:, :, :length, :, :]
-def create_dataloader(
- ds_config,
- batch_size,
- mixed_strategy=None,
- mixed_image_ratio=0.0,
- num_parallel_workers=12,
- max_rowsize=32,
- shuffle=True,
- device_num=1,
- rank_id=0,
- drop_remainder=True,
-):
- """
- Args:
- mixed_strategy:
- None - all output batches are videoes [bs, c, T, h, w]
- mixed_video_image - with prob of mixed_image_ratio, output batch are images [b, c, 1, h, w]
- mixed_video_random - output batch has a random number of frames [bs, c, t, h, w], t is the same of samples in a batch
- mixed_image_ratio:
- ds_config, dataset config, args for ImageDataset or VideoDataset
- ds_name: dataset name, image or video
- """
- dataset = VideoDataset(**ds_config)
- print("Total number of samples: ", len(dataset))
-
- # Larger value leads to more memory consumption. Default: 16
- # prefetch_size = config.get("prefetch_size", 16)
- # ms.dataset.config.set_prefetch_size(prefetch_size)
-
- dataloader = ms.dataset.GeneratorDataset(
- source=dataset,
- column_names=["video"],
- num_shards=device_num,
- shard_id=rank_id,
- python_multiprocessing=True,
- shuffle=shuffle,
- num_parallel_workers=num_parallel_workers,
- max_rowsize=max_rowsize,
- )
-
- dl = dataloader.batch(
- batch_size,
- drop_remainder=drop_remainder,
- )
-
- if mixed_strategy is not None:
- batch_map_fn = BatchTransform(mixed_strategy, mixed_image_ratio)
- dl = dl.map(
- operations=batch_map_fn,
- input_columns=["video"],
- num_parallel_workers=1,
- )
+ @staticmethod
+ def _image_only(x: np.ndarray) -> np.ndarray:
+ return x[:, :, :1, :, :]
- return dl
+ def __call__(self, x):
+ # x: (bs, c, t, h, w)
+ return self._trans_fn(x)
if __name__ == "__main__":
+ from mindone.data import create_dataloader
+
test = "dl"
if test == "dataset":
ds_config = dict(
- data_folder="../videocomposer/datasets/webvid5",
+ folder="../videocomposer/datasets/webvid5",
random_crop=True,
flip=True,
)
@@ -312,19 +290,16 @@ def create_dataloader(
ds_config = dict(
csv_path="../videocomposer/datasets/webvid5_copy.csv",
- data_folder="../videocomposer/datasets/webvid5",
+ folder="../videocomposer/datasets/webvid5",
sample_n_frames=17,
size=128,
crop_size=128,
)
+ ds = VideoDataset(**ds_config)
+ bt = BatchTransform(mixed_strategy="mixed_video_random", mixed_image_ratio=0.2)
# test loader
- dl = create_dataloader(
- ds_config,
- 4,
- mixed_strategy="mixed_video_random",
- mixed_image_ratio=0.2,
- )
+ dl = create_dataloader(ds, batch_size=4, batch_transforms={"operations": bt, "input_columns": ["video"]})
num_batches = dl.get_dataset_size()
# ms.set_context(mode=0)
diff --git a/examples/moviegen/mg/dataset/transforms.py b/examples/moviegen/mg/dataset/transforms.py
new file mode 100644
index 0000000000..27a6263c81
--- /dev/null
+++ b/examples/moviegen/mg/dataset/transforms.py
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+
+
+class ResizeCrop:
+ """
+ Resize and center crop the input image or video to a target size while preserving the aspect ratio.
+
+ Args:
+ size (Optional[Tuple[int, int]], optional): The target size. If None, the target size should be passed during the call.
+ interpolation (cv2.InterpolationFlags, optional): The interpolation method. Defaults to cv2.INTER_LINEAR.
+ preserve_orientation (bool, optional): Whether to preserve the orientation of the image/video. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ size: Optional[Tuple[int, int]] = None,
+ interpolation: int = cv2.INTER_LINEAR,
+ preserve_orientation: bool = True,
+ ):
+ self._size = size
+ self._inter = interpolation
+ self._po = preserve_orientation
+
+ def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np.ndarray:
+ h, w = x.shape[-3:-1] # support images and videos
+ th, tw = size or self._size
+
+ scale = max(th / h, tw / w)
+ if self._po and (new_scale := max(th / w, tw / h)) < scale: # preserve orientation
+ scale = new_scale
+ th, tw = tw, th
+
+ if scale != 1: # resize
+ if x.ndim == 3: # if image
+ x = cv2.resize(x, None, fx=scale, fy=scale, interpolation=self._inter)
+ else: # if video
+ x = np.array([cv2.resize(i, None, fx=scale, fy=scale, interpolation=self._inter) for i in x])
+
+ if x.shape[-3:-1] != (th, tw): # center crop
+ i, j = round((x.shape[-3] - th) / 2.0), round((x.shape[-2] - tw) / 2.0)
+ x = x[..., i : i + th, j : j + tw, :]
+
+ return x
diff --git a/examples/moviegen/mg/models/__init__.py b/examples/moviegen/mg/models/__init__.py
new file mode 100644
index 0000000000..0290c39540
--- /dev/null
+++ b/examples/moviegen/mg/models/__init__.py
@@ -0,0 +1,3 @@
+from .llama import *
+from .tae import TemporalAutoencoder
+from .text_encoders import *
diff --git a/examples/moviegen/mg/models/llama/__init__.py b/examples/moviegen/mg/models/llama/__init__.py
new file mode 100644
index 0000000000..6cf34ce83b
--- /dev/null
+++ b/examples/moviegen/mg/models/llama/__init__.py
@@ -0,0 +1 @@
+from .network import *
diff --git a/examples/moviegen/mg/models/llama/activation.py b/examples/moviegen/mg/models/llama/activation.py
new file mode 100644
index 0000000000..7b54d885a1
--- /dev/null
+++ b/examples/moviegen/mg/models/llama/activation.py
@@ -0,0 +1,28 @@
+import logging
+from collections import OrderedDict
+
+import mindspore.mint as mint
+import mindspore.nn as nn
+from mindspore import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+class QuickGELU(nn.Cell):
+ def construct(self, x: Tensor):
+ return x * mint.sigmoid(1.702 * x)
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+ACT2CLS = {
+ "quick_gelu": QuickGELU,
+ "gelu": nn.GELU,
+ "silu": nn.SiLU,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
diff --git a/examples/moviegen/mg/models/llama/block.py b/examples/moviegen/mg/models/llama/block.py
new file mode 100644
index 0000000000..6ec990bfbf
--- /dev/null
+++ b/examples/moviegen/mg/models/llama/block.py
@@ -0,0 +1,297 @@
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.mint as mint
+import mindspore.mint.nn.functional as F
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Parameter, Tensor
+from mindspore.communication import get_group_size
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+
+from ...acceleration import get_sequence_parallel_group
+from .activation import ACT2FN
+
+
+class LlamaRMSNorm(nn.Cell):
+ def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6):
+ super().__init__()
+ self.weight = Parameter(np.ones(hidden_size).astype(np.float32)) # keep normalization at FP32
+ self.variance_epsilon = eps
+
+ def construct(self, hidden_states: Tensor) -> Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states, _ = ops.rms_norm(hidden_states.to(ms.float32), self.weight, epsilon=self.variance_epsilon)
+ return hidden_states.to(input_dtype)
+
+
+class LlamaMLP(nn.Cell):
+ def __init__(
+ self,
+ intermediate_size: int = 8192,
+ hidden_size: int = 3072,
+ hidden_act: str = "silu",
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype)
+ self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype)
+ self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False, dtype=dtype)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def construct(self, hidden_state: Tensor) -> Tensor:
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
+ if n_rep == 1:
+ return hidden_states
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ hidden_states = hidden_states[:, :, None, :, :]
+ hidden_states = mint.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim))
+ hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim))
+ return hidden_states
+
+
+class LlamaAttention(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ attention_dropout: float = 0.0,
+ attention_bias: bool = False,
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__()
+
+ self.attention_dropout = attention_dropout
+ self.hidden_size = hidden_size
+ self.num_heads = num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype)
+ self.k_proj = mint.nn.Linear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype
+ )
+ self.v_proj = mint.nn.Linear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype
+ )
+ self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype)
+
+ if (sp_group := get_sequence_parallel_group()) is not None:
+ self.sp_group_size = get_group_size(sp_group)
+ self.alltoall = ops.AlltoAll(self.sp_group_size, 1, 2, group=sp_group)
+ else:
+ self.sp_group_size = None
+ self.alltoall = nn.Identity()
+
+ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor:
+ bsz, q_len, _ = hidden_states.shape
+
+ kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(kv_hidden_states)
+ value_states = self.v_proj(kv_hidden_states)
+
+ query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim))
+ query_states = mint.permute(query_states, (0, 2, 1, 3))
+ query_states = self.alltoall(query_states)
+
+ key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim))
+ key_states = mint.permute(key_states, (0, 2, 1, 3))
+ key_states = self.alltoall(key_states)
+
+ value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim))
+ value_states = mint.permute(value_states, (0, 2, 1, 3))
+ value_states = self.alltoall(value_states)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ key_states = mint.permute(key_states, (0, 1, 3, 2))
+ attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim))
+
+ # upcast attention to fp32
+ attn_weights = attn_weights.to(ms.float32)
+ attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = mint.matmul(attn_weights, value_states)
+
+ attn_output = mint.permute(attn_output, (0, 2, 1, 3))
+ attn_output = self.alltoall(attn_output)
+ attn_output = ops.reshape(attn_output, (bsz, q_len, -1))
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class LlamaFlashAttention(LlamaAttention):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ attention_dropout: float = 0.0,
+ attention_bias: bool = False,
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ attention_dropout=attention_dropout,
+ attention_bias=attention_bias,
+ dtype=dtype,
+ )
+ num_heads = self.num_heads // self.sp_group_size if self.sp_group_size is not None else self.num_heads
+ self.flash_attention = FlashAttentionScore(
+ num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND"
+ )
+
+ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor:
+ bsz, q_len, _ = hidden_states.shape
+
+ kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(kv_hidden_states)
+ value_states = self.v_proj(kv_hidden_states)
+
+ query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim))
+ query_states = mint.permute(query_states, (0, 2, 1, 3))
+ query_states = self.alltoall(query_states)
+
+ key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim))
+ key_states = mint.permute(key_states, (0, 2, 1, 3))
+ key_states = self.alltoall(key_states)
+
+ value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim))
+ value_states = mint.permute(value_states, (0, 2, 1, 3))
+ value_states = self.alltoall(value_states)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Reshape to the expected shape and dtype for Flash Attention
+ query_states = mint.permute(query_states, (0, 2, 1, 3))
+ key_states = mint.permute(key_states, (0, 2, 1, 3))
+ value_states = mint.permute(value_states, (0, 2, 1, 3))
+
+ _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None)
+ attn_output = self.alltoall(attn_output)
+ attn_output = ops.reshape(attn_output, (bsz, q_len, -1))
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class PatchEmbed3D(nn.Cell):
+ def __init__(
+ self,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ in_channels: int = 8,
+ hidden_size: int = 4096,
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Conv3d(
+ in_channels,
+ hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ pad_mode="pad",
+ has_bias=False,
+ dtype=dtype,
+ )
+
+ def construct(self, x: Tensor) -> Tensor:
+ _, t, _, h, w = x.shape
+ # assert t % self.patch_size[0] == 0
+ # assert h % self.patch_size[1] == 0
+ # assert w % self.patch_size[2] == 0
+
+ x = mint.permute(x, (0, 2, 1, 3, 4))
+ x = self.proj(x) # (B C T H W)
+ x = mint.flatten(x, start_dim=2)
+ x = mint.permute(x, (0, 2, 1))
+ return x
+
+
+class LinearPatchEmbed3D(nn.Cell):
+ def __init__(
+ self,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ in_channels: int = 8,
+ hidden_size: int = 4096,
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = mint.nn.Linear(
+ patch_size[0] * patch_size[1] * patch_size[2] * in_channels, hidden_size, bias=False, dtype=dtype
+ )
+
+ def construct(self, x: Tensor) -> Tensor:
+ b, t, c, h, w = x.shape
+ # assert t % self.patch_size[0] == 0
+ # assert h % self.patch_size[1] == 0
+ # assert w % self.patch_size[2] == 0
+
+ p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ nt, nh, nw = t // p0, h // p1, w // p2
+ x = ops.reshape(x, (b, nt, p0, c, nh, p1, nw, p2))
+ x = mint.permute(x, (0, 1, 4, 6, 3, 2, 5, 7)) # (B, nt, nh, nw, c, p0, p1, p2)
+ x = ops.reshape(x, (b, nt * nh * nw, -1))
+ x = self.proj(x)
+ return x
+
+
+class TimestepEmbedder(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int,
+ frequency_embedding_size: int = 256,
+ hidden_act: str = "silu",
+ dtype: ms.Type = ms.float32,
+ ) -> None:
+ super().__init__()
+ self.mlp = nn.SequentialCell(
+ mint.nn.Linear(frequency_embedding_size, hidden_size, bias=False, dtype=dtype),
+ ACT2FN[hidden_act],
+ mint.nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+ self._dtype = dtype
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @staticmethod
+ def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor:
+ half = dim // 2
+ freqs = mint.exp(-mint.log(Tensor(max_period)) * mint.arange(start=0, end=half, dtype=ms.float32) / half)
+ args = ops.unsqueeze(t, 1).to(ms.float32) * ops.unsqueeze(freqs, 0)
+ embedding = mint.cat([mint.cos(args), mint.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = mint.cat([embedding, mint.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def construct(self, t: Tensor) -> Tensor:
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq.to(self.dtype))
+ return t_emb
diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py
new file mode 100644
index 0000000000..b6474bfbe0
--- /dev/null
+++ b/examples/moviegen/mg/models/llama/network.py
@@ -0,0 +1,441 @@
+from __future__ import annotations
+
+import logging
+from typing import Literal, Optional, Tuple, Union
+
+import numpy as np
+
+from mindspore import Parameter, Tensor
+from mindspore import dtype as mstype
+from mindspore import lazy_inline, load_checkpoint, mint, nn, ops
+
+from mindone.models.utils import normal_, zeros_
+
+from ...acceleration import GatherFowardSplitBackward, SplitFowardGatherBackward, get_sequence_parallel_group
+from ..text_encoders import TextProjector
+from .activation import ACT2FN
+from .block import (
+ LinearPatchEmbed3D,
+ LlamaAttention,
+ LlamaFlashAttention,
+ LlamaMLP,
+ LlamaRMSNorm,
+ PatchEmbed3D,
+ TimestepEmbedder,
+)
+
+__all__ = ["LlamaModel", "llama3_1B", "llama3_5B", "llama3_30B"]
+
+_logger = logging.getLogger(__name__)
+
+Llama_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention": LlamaFlashAttention,
+}
+
+
+def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
+ return x * (1 + scale) + shift
+
+
+class LlamaDecoderLayer(nn.Cell):
+ @lazy_inline(policy="front")
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ intermediate_size: int = 14336,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-5,
+ attention_dropout: float = 0.0,
+ attention_bias: bool = False,
+ hidden_act: str = "silu",
+ attn_implementation: Literal["eager", "flash_attention"] = "eager",
+ dtype: mstype = mstype.float32,
+ ) -> None:
+ super().__init__()
+
+ self.self_attn = Llama_ATTENTION_CLASSES[attn_implementation](
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ attention_dropout=attention_dropout,
+ attention_bias=attention_bias,
+ dtype=dtype,
+ )
+
+ self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation](
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ attention_dropout=attention_dropout,
+ attention_bias=attention_bias,
+ dtype=dtype,
+ )
+
+ self.mlp = LlamaMLP(
+ intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype
+ )
+
+ self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size) / hidden_size**0.5, dtype=dtype))
+ self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
+
+ def construct(
+ self,
+ hidden_states: Tensor,
+ encoder_hidden_states: Tensor,
+ modulation_parameters: Tensor,
+ position_embedding: Tensor,
+ ) -> Tensor:
+ B = hidden_states.shape[0]
+
+ # 3.1.3 Positional Embedding
+ hidden_states = hidden_states + position_embedding
+
+ # 3.1.3 Adaptive Layer Norm
+ modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + ops.reshape(
+ modulation_parameters, (B, 6, -1)
+ )
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1)
+
+ # Self-Attention (Bi-Directional Attention)
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa)
+ hidden_states = self.self_attn(hidden_states)
+ hidden_states = gate_msa * hidden_states
+ hidden_states = residual + hidden_states
+
+ # 3.1.3 Cross Attention
+ residual = hidden_states
+ hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = gate_mlp * hidden_states
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class LlamaFinalLayer(nn.Cell):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ out_channels: int = 8,
+ rms_norm_eps: float = 1e-5,
+ dtype: mstype = mstype.float32,
+ ) -> None:
+ super().__init__()
+ self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
+ self.proj = mint.nn.Linear(
+ hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False, dtype=dtype
+ )
+ self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size) / hidden_size**0.5, dtype=dtype))
+
+ def construct(self, hidden_states: Tensor, timestep_embedding: Tensor):
+ shift, scale = mint.chunk(
+ ops.unsqueeze(self.scale_shift_table, 0) + ops.unsqueeze(timestep_embedding, 1), 2, dim=1
+ )
+ hidden_states = t2i_modulate(self.input_layernorm(hidden_states), shift, scale)
+ hidden_states = self.proj(hidden_states)
+ return hidden_states
+
+
+class LlamaModel(nn.Cell):
+ def __init__(
+ self,
+ in_channels: int = 8,
+ out_channels: Optional[int] = None,
+ hidden_size: int = 4096,
+ intermediate_size: int = 14336,
+ num_attention_heads: int = 32,
+ num_hidden_layers: int = 32,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-5,
+ attention_dropout: float = 0.0,
+ attention_bias: bool = False,
+ hidden_act: str = "silu",
+ initializer_range: float = 0.02,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ max_length: Tuple[int, int, int] = (128, 64, 64),
+ attn_implementation: Literal["eager", "flash_attention"] = "eager",
+ recompute_every_nth_block: Optional[int] = None,
+ use_linear_patch_embedder: bool = True,
+ post_init_weight: bool = True,
+ dtype: mstype.Type = mstype.float32,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.rms_norm_eps = rms_norm_eps
+ self.max_length = max_length
+ self._dtype = dtype
+
+ self.layers = nn.CellList(
+ [
+ LlamaDecoderLayer(
+ hidden_size=self.hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ attention_dropout=attention_dropout,
+ attention_bias=attention_bias,
+ hidden_act=hidden_act,
+ attn_implementation=attn_implementation,
+ dtype=dtype,
+ )
+ for _ in range(num_hidden_layers)
+ ]
+ )
+
+ self.final_layer = LlamaFinalLayer(
+ hidden_size=self.hidden_size,
+ patch_size=self.patch_size,
+ out_channels=self.out_channels,
+ rms_norm_eps=rms_norm_eps,
+ dtype=dtype,
+ )
+
+ self.pos_embedding_table_t = nn.Embedding(max_length[0], self.hidden_size, dtype=dtype)
+ self.pos_embedding_table_h = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype)
+ self.pos_embedding_table_w = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype)
+
+ if use_linear_patch_embedder:
+ self.latent_embedder = LinearPatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype)
+ else:
+ self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype)
+
+ self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype)
+ self.adaLN_modulation = nn.SequentialCell(
+ ACT2FN[hidden_act], mint.nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=False, dtype=dtype)
+ )
+
+ self.text_projector = TextProjector(
+ out_features=self.hidden_size, layer_norm=LlamaRMSNorm, norm_eps=self.rms_norm_eps, dtype=dtype
+ )
+
+ # init sequence parallel
+ if (sp_group := get_sequence_parallel_group()) is not None:
+ _logger.info(f"Initialize Llama model with sequence parallel group `{sp_group}`.")
+ self.split_forward_gather_backward = SplitFowardGatherBackward(dim=1, grad_scale="down", group=sp_group)
+ self.gather_forward_split_backward = GatherFowardSplitBackward(dim=1, grad_scale="up", group=sp_group)
+ else:
+ self.split_forward_gather_backward = nn.Identity()
+ self.gather_forward_split_backward = nn.Identity()
+
+ # post-init
+ if post_init_weight:
+ self.initializer_range = initializer_range
+ self.init_weights()
+
+ if recompute_every_nth_block is not None:
+ _logger.info(f"Recomputing every {recompute_every_nth_block} block.")
+ for i, layer in enumerate(self.layers):
+ if i % recompute_every_nth_block == 0:
+ layer.recompute()
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ def init_weights(self):
+ std = self.initializer_range
+
+ def _init_weights(module):
+ if isinstance(module, mint.nn.Linear):
+ normal_(module.weight, mean=0.0, std=std)
+ if module.bias is not None:
+ zeros_(module.weight)
+ elif isinstance(module, nn.Embedding):
+ normal_(module.embedding_table, mean=0.0, std=std)
+
+ self.apply(_init_weights)
+
+ # Initialize patch_embed like nn.Dense (instead of nn.Conv3d):
+ normal_(self.latent_embedder.proj.weight, mean=0.0, std=std)
+ if self.latent_embedder.proj.bias is not None:
+ zeros_(self.latent_embedder.proj.bias)
+
+ # Zero-out adaLN modulation block:
+ zeros_(self.adaLN_modulation[-1].weight)
+ if self.adaLN_modulation[-1].bias is not None:
+ zeros_(self.adaLN_modulation[-1].bias)
+
+ # Zero-out final block as DiT does
+ zeros_(self.final_layer.proj.weight)
+ if self.final_layer.proj.bias is not None:
+ zeros_(self.final_layer.proj.bias)
+
+ def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor:
+ # 3.1.3
+ _, t, _, h, w = latent_embedding.shape
+ p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ nt, nh, nw = t // p0, h // p1, w // p2
+
+ # assert nt < self.max_length[0]
+ # assert nh < self.max_length[1]
+ # assert nw < self.max_length[2]
+
+ t_inds = mint.arange(nt, dtype=mstype.int64)
+ h_inds = mint.arange(nh, dtype=mstype.int64)
+ w_inds = mint.arange(nw, dtype=mstype.int64)
+
+ position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij")
+ position_ids = ops.stack(position_ids, axis=-1)
+ position_ids = ops.reshape(position_ids, (-1, 3))
+
+ t_inds, h_inds, w_inds = ops.unbind(position_ids, dim=-1)
+ pos_embed_t = self.pos_embedding_table_t(t_inds)
+ pos_embed_h = self.pos_embedding_table_h(h_inds)
+ pos_embed_w = self.pos_embedding_table_w(w_inds)
+ pos_embed = pos_embed_t + pos_embed_h + pos_embed_w
+ pos_embed = ops.unsqueeze(pos_embed, 0)
+ return pos_embed
+
+ def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor:
+ """
+ hidden_states: (N, T, patch_size[0] * patch_size[1] * patch_size[2] * C)
+ """
+ bs = hidden_states.shape[0]
+ c = self.out_channels
+ p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ nt, nh, nw = t // p0, h // p1, w // p2
+
+ hidden_states = ops.reshape(hidden_states, (bs, nt, nh, nw, p0, p1, p2, c))
+ # bs, nt, p0, c, nh, p1, nw, p2, c
+ hidden_states = mint.permute(hidden_states, (0, 1, 4, 7, 2, 5, 3, 6))
+ output = ops.reshape(hidden_states, (bs, nt * p0, c, nh * p1, nw * p2))
+ return output
+
+ def construct(
+ self, latent_embedding: Tensor, timestep: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor
+ ) -> Tensor:
+ """
+ latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video)
+ timestep: (N,) tensor to indicate a denoising step
+ ul2_emb: (N, L1, 4096) UL2 text embeddings
+ metaclip_emb: (N, L2, 1280) MetaCLIP text embeddings
+ byt5_emb: (N, L3, 1472) ByT5 text embeddings
+ """
+ _, t, _, h, w = latent_embedding.shape
+
+ # create position embedding to be shared across the decoder layers
+ position_embedding = self.learnable_position_embedding(latent_embedding)
+ position_embedding = position_embedding.to(latent_embedding.dtype)
+
+ # patchify and embed latent in transformer hidden dim.
+ latent_embedding = self.latent_embedder(latent_embedding)
+
+ # 6.1.2 shared timestep embedding & modulation. It does not mention the detail structure, we follow PixArt-Alpha here
+ timestep_embedding = self.timestep_embedder(timestep)
+ modulation_parameters = self.adaLN_modulation(timestep_embedding)
+
+ # 3.1.4 text embedding
+ text_embedding = self.text_projector(ul2_emb, metaclip_emb, byt5_emb)
+
+ # sequence parallel start
+ latent_embedding = self.split_forward_gather_backward(latent_embedding)
+ position_embedding = self.split_forward_gather_backward(position_embedding)
+
+ # main blocks
+ hidden_states = latent_embedding
+ for decoder_layer in self.layers:
+ hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding)
+
+ # sequence parallel end
+ hidden_states = self.gather_forward_split_backward(hidden_states)
+
+ # final block
+ hidden_states = self.final_layer(hidden_states, timestep_embedding)
+
+ # unpatchify
+ output = self.unpatchify(hidden_states, t, h, w)
+ return output
+
+ def construct_with_cfg(
+ self,
+ latent_embedding: Tensor,
+ timestep: Tensor,
+ text_embedding: Tensor,
+ cfg_scale: Union[Tensor, float] = 7.5,
+ ) -> Tensor:
+ """
+ latent_embedding: (2N, T, C, H, W) tensor of inputs (latent representations of video)
+ timestep: (2N,) tensor to indicate denoising step
+ text_embedding: (2N, L, C') tensor of the text embedding
+ cfg_scale: CFG scale
+ """
+ model_out = self(latent_embedding, timestep, text_embedding)
+ cond_model_out, uncond_model_out = mint.chunk(model_out, 2, dim=0)
+ model_out = uncond_model_out + cfg_scale * (cond_model_out - uncond_model_out)
+ model_out = mint.tile(model_out, (2, 1, 1, 1, 1))
+ return model_out
+
+
+def llama3_1B(from_pretrained=None, **kwargs):
+ model = LlamaModel(
+ attention_bias=False,
+ attention_dropout=0.0,
+ hidden_act="silu",
+ hidden_size=1536,
+ initializer_range=0.02,
+ intermediate_size=4096,
+ num_attention_heads=16,
+ num_hidden_layers=24,
+ num_key_value_heads=16,
+ rms_norm_eps=1e-05,
+ **kwargs,
+ )
+ if from_pretrained is not None:
+ load_checkpoint(from_pretrained, model)
+ return model
+
+
+def llama3_5B(from_pretrained=None, **kwargs):
+ model = LlamaModel(
+ attention_bias=False,
+ attention_dropout=0.0,
+ hidden_act="silu",
+ hidden_size=3072,
+ initializer_range=0.02,
+ intermediate_size=8192,
+ num_attention_heads=24,
+ num_hidden_layers=32,
+ num_key_value_heads=24,
+ rms_norm_eps=1e-05,
+ **kwargs,
+ )
+ if from_pretrained is not None:
+ load_checkpoint(from_pretrained, model)
+ return model
+
+
+def llama3_30B(from_pretrained=None, **kwargs):
+ model = LlamaModel(
+ attention_bias=False,
+ attention_dropout=0.0,
+ hidden_act="silu",
+ hidden_size=6144,
+ initializer_range=0.02,
+ intermediate_size=16384,
+ num_attention_heads=48,
+ num_hidden_layers=48,
+ num_key_value_heads=48,
+ rms_norm_eps=1e-05,
+ **kwargs,
+ )
+ if from_pretrained is not None:
+ load_checkpoint(from_pretrained, model)
+ return model
diff --git a/examples/moviegen/mg/models/tae/__init__.py b/examples/moviegen/mg/models/tae/__init__.py
new file mode 100644
index 0000000000..75d32e29fc
--- /dev/null
+++ b/examples/moviegen/mg/models/tae/__init__.py
@@ -0,0 +1 @@
+from .tae import TemporalAutoencoder
diff --git a/examples/moviegen/mg/models/tae/tae.py b/examples/moviegen/mg/models/tae/tae.py
index cbc345c68a..0c59f4a1ac 100644
--- a/examples/moviegen/mg/models/tae/tae.py
+++ b/examples/moviegen/mg/models/tae/tae.py
@@ -1,7 +1,9 @@
import math
+from typing import Literal, Optional, Tuple
-import mindspore as ms
-from mindspore import nn, ops
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import load_checkpoint, load_param_into_net, mint, nn, ops
from .modules import Conv2_5d, Decoder, Encoder
@@ -54,7 +56,7 @@ class TemporalAutoencoder(nn.Cell):
def __init__(
self,
config: dict = TAE_CONFIG,
- pretrained: str = None,
+ pretrained: Optional[str] = None,
use_recompute: bool = False,
sample_deterministic: bool = False,
use_tile: bool = False,
@@ -62,8 +64,14 @@ def __init__(
encode_overlap: int = 0,
decode_tile: int = 32,
decode_overlap: int = 16,
+ dtype: Literal["fp32", "fp16", "bf16"] = "fp32",
):
super().__init__()
+ self.out_channels = config["z_channels"]
+ self.scale_factor = config["scaling_factor"]
+ self.shift_factor = config["shift_factor"]
+ # not used yet, just for CLI initialization convenience
+ self._dtype = {"fp32": mstype.float32, "fp16": mstype.float16, "bf16": mstype.bfloat16}[dtype]
# encoder
self.encoder = Encoder(**config)
@@ -83,7 +91,7 @@ def __init__(
self.exp = ops.Exp()
self.stdnormal = ops.StandardNormal()
- self.split = ms.ops.split
+ self.split = ops.split
self.sample_deterministic = sample_deterministic
self.discard_spurious_frames = True
@@ -110,9 +118,13 @@ def __init__(
self.recompute(self.encoder)
self.recompute(self.decoder)
- if pretrained is not None:
+ if pretrained:
self.load_pretrained(pretrained)
+ @property
+ def dtype(self):
+ return self._dtype
+
def recompute(self, b):
if not b._has_config_recompute:
b.recompute()
@@ -140,7 +152,7 @@ def sample(self, mean, logvar):
return z
- def encode(self, x: ms.Tensor) -> ms.Tensor:
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Encode a batch of videos into latents
@@ -152,7 +164,12 @@ def encode(self, x: ms.Tensor) -> ms.Tensor:
posterior_mean (Tensor): mean of latent distribution
posterior_logvar (Tensor): logvar of latent distribution
"""
+ if self.use_tile:
+ return self.encode_with_tile(x)
+ else:
+ return self._encode_no_tile(x)
+ def _encode_no_tile(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
posterior_mean, posterior_logvar = self._encode(x)
if self.sample_deterministic:
return posterior_mean
@@ -160,19 +177,25 @@ def encode(self, x: ms.Tensor) -> ms.Tensor:
return z, posterior_mean, posterior_logvar
- def decode(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor:
+ def decode(self, z: Tensor, target_num_frames: int = None) -> Tensor:
r"""
Decode a batch of latents to videos
Args:
- x (Tensor): input latent tensor of shape (b z t' h' w')
- target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. \
- Otherwise, the previous this number of frames will be reserved.
+ z (Tensor): input latent tensor of shape (b z t' h' w')
+ target_num_frames (int): target number of frames for output.
+ If None, all the decoded frames will be reserved.
+ Otherwise, the previous this number of frames will be reserved.
Returns:
z (Tensor): the decoded videos of shape (b c t h w)
"""
+ if self.use_tile:
+ return self.decode_with_tile(z, target_num_frames)
+ else:
+ return self._decode_no_tile(z, target_num_frames)
+ def _decode_no_tile(self, z: Tensor, target_num_frames: int = None) -> Tensor:
if self.use_post_quant_conv:
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -182,7 +205,7 @@ def decode(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor:
return dec
- def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor:
+ def encode_with_tile(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Encode a batch of videos into latents with tiling
@@ -197,23 +220,25 @@ def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor:
tf = self.encode_tile
- z_out, mean, logvar = self.encode(x[:, :, :tf])
+ z_out, mean, logvar = self._encode_no_tile(x[:, :, :tf])
for i in range(tf, x.shape[2], tf):
- z_cur, mean, logvar = self.encode(x[:, :, i : i + tf])
- z_out = ops.cat((z_out, z_cur), axis=2)
+ z_cur, mean_cur, logvar_cur = self._encode_no_tile(x[:, :, i : i + tf])
+ z_out = mint.cat((z_out, z_cur), dim=2)
+ mean = mint.cat((mean, mean_cur), dim=2)
+ logvar = mint.cat((logvar, logvar_cur), dim=2)
- # TODO: merge mean, logvar for different slices for training tae with tile
return z_out, mean, logvar
- def decode_with_tile(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor:
+ def decode_with_tile(self, z: Tensor, target_num_frames: int = None) -> Tensor:
r"""
Decode a batch of latents to videos with tiling
Args:
x (Tensor): input latent tensor of shape (b z t' h' w')
- target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. \
- Otherwise, the previous this number of frames will be reserved.
+ target_num_frames (int): target number of frames for output.
+ If None, all the decoded frames will be reserved.
+ Otherwise, the previous this number of frames will be reserved.
Returns:
z (Tensor): the decoded videos of shape (b c t h w)
@@ -228,13 +253,13 @@ def decode_with_tile(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Te
num_slices += 1
# ms graph mode requires an init x_out
- x_out = self.decode(z[:, :, :tl])
+ x_out = self._decode_no_tile(z[:, :, :tl])
visited = tl
i = stride # start position
while visited < in_len:
- x_cur = self.decode(z[:, :, i : i + tl])
- x_out = ops.cat((x_out, x_cur), axis=2)
+ x_cur = self._decode_no_tile(z[:, :, i : i + tl])
+ x_out = mint.cat((x_out, x_cur), dim=2)
visited = i + tl
i += stride
@@ -248,7 +273,7 @@ def decode_with_tile(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Te
return x_out
- def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16):
+ def blend_slices(self, x: Tensor, slice_len=32, overlap_len=16):
"""
Blend decoded latent slices, used with decode_with_tile
@@ -260,7 +285,7 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16):
Note that the length of the last slice can be shorter than slice_len.
Returns:
- ms.Tensor
+ Tensor
"""
B, C, in_len, H, W = x.shape
@@ -271,8 +296,8 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16):
last_slice_len = in_len - (num_slices - 1) * slice_len
out_len += last_slice_len - overlap_len
- out_tensor = ops.zeros((B, C, out_len, H, W), ms.float32)
- out_cnt = ops.zeros((B, C, out_len, H, W), ms.float32)
+ out_tensor = mint.zeros((B, C, out_len, H, W), dtype=mstype.float32)
+ out_cnt = mint.zeros((B, C, out_len, H, W), dtype=mstype.float32)
for i in range(num_slices):
# get the slice form the concatnated latent
@@ -288,7 +313,7 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16):
return out_tensor
- def construct(self, x: ms.Tensor) -> ms.Tensor:
+ def construct(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""
Video reconstruction
@@ -308,10 +333,7 @@ def construct(self, x: ms.Tensor) -> ms.Tensor:
posterior_mean, posterior_logvar = self._encode(x)
z = self.sample(posterior_mean, posterior_logvar)
- if self.use_tile:
- recons = self.decode_with_tile(z)
- else:
- recons = self.decode(z)
+ recons = self.decode(z)
if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]):
recons = recons[:, :, : x.shape[-3], :, :]
@@ -329,7 +351,7 @@ def load_pretrained(self, ckpt_path: str):
state_dict[key] = ckpt.get_tensor(key)
raise NotImplementedError
else:
- param_dict = ms.load_checkpoint(ckpt_path)
+ param_dict = load_checkpoint(ckpt_path)
# remove the added prefix in the trained checkpoint
pnames = list(param_dict.keys())
@@ -337,10 +359,15 @@ def load_pretrained(self, ckpt_path: str):
new_pn = pn.replace("autoencoder.", "").replace("_backbone.", "")
param_dict[new_pn] = param_dict.pop(pn)
- param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True)
+ param_not_load, ckpt_not_load = load_param_into_net(self, param_dict, strict_load=True)
if param_not_load or ckpt_not_load:
print(f"{param_not_load} in network is not loaded")
print(f"{ckpt_not_load} in checkpoint is not loaded!")
print("TAE checkpoint loaded")
+
+ @staticmethod
+ def get_latent_size(input_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
+ # FIXME: validate
+ return max(input_size[0] // 8, 1), input_size[1] // 8, input_size[2] // 8
diff --git a/examples/moviegen/mg/models/text_encoders/__init__.py b/examples/moviegen/mg/models/text_encoders/__init__.py
new file mode 100644
index 0000000000..c26604e0d0
--- /dev/null
+++ b/examples/moviegen/mg/models/text_encoders/__init__.py
@@ -0,0 +1 @@
+from .text_projector import TextProjector
diff --git a/examples/moviegen/mg/models/text_encoders/text_projector.py b/examples/moviegen/mg/models/text_encoders/text_projector.py
new file mode 100644
index 0000000000..edcebd848b
--- /dev/null
+++ b/examples/moviegen/mg/models/text_encoders/text_projector.py
@@ -0,0 +1,56 @@
+from typing import Type
+
+import mindspore as ms
+from mindspore import Tensor, mint, nn
+
+from mindone.models.utils import normal_, zeros_
+
+
+class TextProjector(nn.Cell):
+ def __init__(
+ self,
+ ul2_in_features: int = 4096,
+ metaclip_in_features: int = 1280,
+ byt5_in_features: int = 1472,
+ out_features: int = 6144,
+ layer_norm: Type[nn.Cell] = mint.nn.LayerNorm,
+ norm_eps: float = 1e-5,
+ initializer_range: float = 0.02,
+ post_init_weight: bool = True,
+ dtype: ms.Type = ms.float32,
+ ):
+ super().__init__()
+ # split layers for easier exclusion from weight decay
+ self.ul2_linear = mint.nn.Linear(ul2_in_features, out_features, bias=False, dtype=dtype)
+ self.ul2_layernorm = layer_norm((out_features,), eps=norm_eps)
+
+ self.metaclip_linear = mint.nn.Linear(metaclip_in_features, out_features, bias=False, dtype=dtype)
+ self.metaclip_layernorm = layer_norm((out_features,), eps=norm_eps)
+
+ self.byt5_linear = mint.nn.Linear(byt5_in_features, out_features, bias=False, dtype=dtype)
+ self.byt5_layernorm = layer_norm((out_features,), eps=norm_eps)
+
+ self.initializer_range = initializer_range
+
+ # post-init
+ if post_init_weight:
+ self.initializer_range = initializer_range
+ self.init_weights()
+
+ def init_weights(self):
+ std = self.initializer_range
+
+ def _init_weights(module):
+ if isinstance(module, mint.nn.Linear):
+ normal_(module.weight, mean=0.0, std=std)
+ if module.bias is not None:
+ zeros_(module.weight)
+
+ self.apply(_init_weights)
+
+ def construct(self, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor) -> Tensor:
+ ul2_hidden_states = self.ul2_layernorm(self.ul2_linear(ul2_emb))
+ metaclip_hidden_states = self.metaclip_layernorm(self.metaclip_linear(metaclip_emb))
+ byt5_hidden_states = self.byt5_layernorm(self.byt5_linear(byt5_emb))
+
+ return mint.cat((ul2_hidden_states, metaclip_hidden_states, byt5_hidden_states), dim=1)
diff --git a/examples/moviegen/mg/pipelines/__init__.py b/examples/moviegen/mg/pipelines/__init__.py
new file mode 100644
index 0000000000..93ba177d16
--- /dev/null
+++ b/examples/moviegen/mg/pipelines/__init__.py
@@ -0,0 +1,2 @@
+from .infer_pipeline import InferPipeline
+from .train_pipeline import DiffusionWithLoss
diff --git a/examples/moviegen/mg/pipelines/infer_pipeline.py b/examples/moviegen/mg/pipelines/infer_pipeline.py
new file mode 100644
index 0000000000..6e0e5eca76
--- /dev/null
+++ b/examples/moviegen/mg/pipelines/infer_pipeline.py
@@ -0,0 +1,93 @@
+from typing import Literal, Optional, Tuple, Union
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import Tensor, mint, ops
+
+from ..models import LlamaModel, TemporalAutoencoder
+from ..schedulers.rectified_flow import RFLOW
+
+__all__ = ["InferPipeline"]
+
+
+class InferPipeline:
+ """An Inference pipeline for Movie Gen.
+
+ Args:
+ model (LlamaModel): A noise prediction model to denoise the encoded image latents.
+ tae (TemporalAutoencoder, optional): Temporal Auto-Encoder (TAE) Model to encode and decode images or videos to
+ and from latent representations.
+ scale_factor (float): scale_factor for TAE.
+ guidance_scale (float): A higher guidance scale value for noise rescale.
+ num_sampling_steps: (int): The number of denoising steps.
+ """
+
+ def __init__(
+ self,
+ model: LlamaModel,
+ tae: Optional[TemporalAutoencoder] = None,
+ latent_size: Tuple[int, int, int] = (1, 64, 64),
+ guidance_scale: float = 1.0,
+ num_sampling_steps: int = 50,
+ sample_method: Literal["linear", "linear-quadratic"] = "linear",
+ micro_batch_size: Optional[int] = None,
+ ):
+ super().__init__()
+ self.model = model
+ self.tae = tae
+ self.latent_size = latent_size
+ self.micro_batch_size = micro_batch_size
+ self.guidance_rescale = guidance_scale
+ self.use_cfg = guidance_scale > 1.0
+ self.rflow = RFLOW(num_sampling_steps, sample_method=sample_method)
+
+ def tae_decode_video(self, x, num_frames=None):
+ """
+ Args:
+ x: (b t c h w), denoised latent
+ Return:
+ y: (b f H W 3), batch of images, normalized to [0, 1]
+ """
+ x = mint.permute(x, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy
+ x = x / self.tae.scale_factor + self.tae.shift_factor
+ y = self.tae.decode(x, target_num_frames=num_frames)
+ y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0)
+ # (b 3 t h w) -> (b t h w 3)
+ y = mint.permute(y, (0, 2, 3, 4, 1))
+ return y
+
+ def __call__(
+ self, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, num_frames: int = None
+ ) -> Tuple[Union[Tensor, None], Tensor]:
+ """
+ args:
+ inputs: dict
+
+ return:
+ images (b H W 3)
+ """
+ z = ms.Tensor(
+ np.random.randn(
+ ul2_emb.shape[0], self.latent_size[0], self.model.in_channels, self.latent_size[1], self.latent_size[2]
+ ).astype(np.float32),
+ dtype=self.model.dtype,
+ )
+ if self.use_cfg:
+ raise NotImplementedError("Condition-free guidance is not supported yet.")
+
+ latents = self.rflow(
+ self.model,
+ z,
+ ul2_emb.to(self.model.dtype),
+ metaclip_emb.to(self.model.dtype),
+ byt5_emb.to(self.model.dtype),
+ ).to(ms.float32)
+
+ if self.tae is not None:
+ # latents: (b t c h w)
+ # out: (b T H W C)
+ images = self.tae_decode_video(latents, num_frames=num_frames)
+ return images, latents
+ else:
+ return None, latents
diff --git a/examples/moviegen/mg/pipelines/train_pipeline.py b/examples/moviegen/mg/pipelines/train_pipeline.py
new file mode 100644
index 0000000000..9f2b5301d2
--- /dev/null
+++ b/examples/moviegen/mg/pipelines/train_pipeline.py
@@ -0,0 +1,71 @@
+from typing import Optional
+
+import mindspore as ms
+from mindspore import Tensor, mint, nn, ops
+
+from ..models import TemporalAutoencoder
+from ..schedulers import RFlowLossWrapper
+from ..utils.model_utils import no_grad
+
+__all__ = ["DiffusionWithLoss"]
+
+
+class DiffusionWithLoss(nn.Cell):
+ def __init__(
+ self,
+ network: RFlowLossWrapper,
+ tae: Optional[TemporalAutoencoder] = None,
+ text_encoder: Optional[nn.Cell] = None,
+ text_emb_cached: bool = True,
+ video_emb_cached: bool = False,
+ ):
+ super().__init__()
+
+ if not text_emb_cached and text_encoder is None:
+ raise ValueError("`text_encoder` must be provided when `text_emb_cached=False`.")
+ if not video_emb_cached and tae is None:
+ raise ValueError("`TAE` must be provided when `video_emb_cached=False`.")
+
+ self.network = network
+ self.tae = tae
+ self.text_encoder = text_encoder
+ self.text_emb_cached = text_emb_cached
+ self.video_emb_cached = video_emb_cached
+
+ if self.tae is not None:
+ for param in self.tae.trainable_params():
+ param.requires_grad = False
+
+ if self.text_encoder is not None:
+ for param in self.text_encoder.trainable_params():
+ param.requires_grad = False
+
+ def get_condition_embeddings(self, text_tokens: Tensor) -> Tensor:
+ if self.text_emb_cached:
+ return text_tokens
+ with no_grad():
+ text_emb = ops.stop_gradient(self.text_encoder(text_tokens))
+ return text_emb
+
+ def get_latents(self, video_tokens: Tensor) -> Tensor:
+ if self.video_emb_cached:
+ return video_tokens
+ with no_grad():
+ # (b c f h w) shape is expected. FIXME: remove this redundancy
+ video_tokens = mint.permute(video_tokens, (0, 2, 1, 3, 4))
+ video_emb = ops.stop_gradient(self.tae.encode(video_tokens)[0]).to(ms.float32)
+ video_emb = (video_emb - self.tae.shift_factor) * self.tae.scale_factor
+ video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME
+ return video_emb
+
+ def set_train(self, mode=True):
+ # Set the diffusion model only to train or eval mode
+ self.network.set_train(mode)
+
+ def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tensor) -> Tensor:
+ latent_embedding = self.get_latents(video_tokens)
+ ul2_emb = self.get_condition_embeddings(ul2_tokens)
+ byt5_emb = self.get_condition_embeddings(byt5_tokens)
+ # FIXME: add metaclip
+ metaclip_emb = mint.ones((byt5_emb.shape[0], 300, 1280), dtype=byt5_emb.dtype)
+ return self.network(latent_embedding, ul2_emb, metaclip_emb, byt5_emb)
diff --git a/examples/moviegen/mg/schedulers/__init__.py b/examples/moviegen/mg/schedulers/__init__.py
new file mode 100644
index 0000000000..d030f82972
--- /dev/null
+++ b/examples/moviegen/mg/schedulers/__init__.py
@@ -0,0 +1 @@
+from .rectified_flow import *
diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py
new file mode 100644
index 0000000000..84e6fbf29d
--- /dev/null
+++ b/examples/moviegen/mg/schedulers/rectified_flow.py
@@ -0,0 +1,186 @@
+import logging
+from math import ceil
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+from tqdm import tqdm
+
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import mint, nn, ops
+from mindspore.communication import get_rank
+
+from ..acceleration import get_sequence_parallel_group
+from ..models import LlamaModel
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["RFLOW", "RFlowLossWrapper", "RFlowEvalLoss"]
+
+
+class LogisticNormal(nn.Cell):
+ def __init__(self, loc: float = 0.0, scale: float = 1.0):
+ super().__init__()
+ self.stdnormal = ops.StandardNormal()
+ self.mean = loc
+ self.std = scale
+ self._min = Tensor(np.finfo(np.float32).tiny, dtype=mstype.float32)
+ self._max = Tensor(1.0 - np.finfo(np.float32).eps, dtype=mstype.float32)
+
+ def construct(self, shape: Tuple[int]) -> Tensor:
+ x = self.mean + self.std * self.stdnormal(shape)
+ return ops.clamp(ops.sigmoid(x), self._min, self._max)
+
+
+class RFLOW:
+ def __init__(
+ self,
+ num_sampling_steps: int = 50,
+ num_timesteps: int = 1000,
+ sample_method: Literal["linear", "linear-quadratic"] = "linear",
+ ) -> None:
+ self.num_sampling_steps = num_sampling_steps
+ self.num_timesteps = num_timesteps
+ self.sample_method = sample_method
+
+ def __call__(self, model: nn.Cell, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor) -> Tensor:
+ """
+ x: (N, T, C, H, W) tensor of inputs (latent representations of video)
+ text_embedding: (N, L, C') tensor of the text embedding
+ """
+ # prepare timesteps
+ if self.sample_method == "linear":
+ timesteps = (1.0 - np.arange(self.num_sampling_steps) / self.num_sampling_steps) * self.num_timesteps
+ else:
+ first_half = ceil(self.num_sampling_steps / 2)
+ second_half = self.num_sampling_steps - first_half # in the case of an odd number of sampling steps
+ linear = self.num_timesteps - np.arange(first_half)
+ quadratic = (np.arange(1, second_half + 1) ** 2) / ((second_half + 1) ** 2)
+ quadratic = (self.num_timesteps - (first_half - 1)) * quadratic + (first_half - 1) # scale and shift
+ quadratic = self.num_timesteps - quadratic
+ timesteps = np.concatenate([linear, quadratic])
+
+ timesteps = np.tile(timesteps[..., None], (1, x.shape[0]))
+ timesteps = Tensor(timesteps, dtype=model.dtype) # FIXME: avoid calculations on tensors outside `construct`
+
+ for i, timestep in tqdm(enumerate(timesteps), total=self.num_sampling_steps):
+ pred = model(x, timestep, ul2_emb, metaclip_emb, byt5_emb)
+
+ # update z
+ dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
+ dt = dt / self.num_timesteps
+ x = x + pred * dt[:, None, None, None, None]
+
+ return x
+
+
+class RFlowLossWrapper(nn.Cell):
+ """Wrapper for calculating the training loss"""
+
+ def __init__(
+ self,
+ model: LlamaModel,
+ num_timesteps: int = 1000,
+ sample_method: Literal["discrete-uniform", "uniform", "logit-normal"] = "logit-normal",
+ loc: float = 0.0,
+ scale: float = 1.0,
+ eps: float = 1e-5,
+ ) -> None:
+ super().__init__(auto_prefix=False)
+ self.num_timesteps = num_timesteps
+ self.eps = eps
+
+ if sample_method == "discrete-uniform":
+ self._sample_func = self._discrete_sample
+ elif sample_method == "uniform":
+ self._sample_func = self._uniform_sample
+ elif sample_method == "logit-normal":
+ self.distribution = LogisticNormal(loc=loc, scale=scale)
+ self._sample_func = self._logit_normal_sample
+ else:
+ raise ValueError(f"Unknown sample method: {sample_method}")
+
+ self.model = model
+ self.criteria = nn.MSELoss()
+
+ self.broadcast = None
+ if (sp_group := get_sequence_parallel_group()) is not None:
+ logging.info(
+ f"Broadcasting all random variables from rank (0) to current rank ({get_rank(sp_group)}) in group `{sp_group}`."
+ )
+ self.broadcast = ops.Broadcast(0, group=sp_group)
+
+ def _discrete_sample(self, size: int) -> Tensor:
+ return ops.randint(0, self.num_timesteps, (size,), dtype=mstype.int64)
+
+ def _uniform_sample(self, size: int) -> Tensor:
+ return mint.rand((size,), dtype=mstype.float32) * self.num_timesteps
+
+ def _logit_normal_sample(self, size: int) -> Tensor:
+ return self.distribution((size,)) * self.num_timesteps
+
+ def _broadcast(self, x: Tensor) -> Tensor:
+ if self.broadcast is None:
+ return x
+ return self.broadcast((x,))[0]
+
+ def construct(
+ self, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, timestep: Optional[Tensor] = None
+ ) -> Tensor:
+ """
+ Calculate the training loss for the corresponding timestep.
+ x: (N, T, C, H, W) tensor of inputs (latent representations of video)
+ ul2_emb: (N, L1, 4096) UL2 text embeddings
+ metaclip_emb: (N, L2, 1280) MetaCLIP text embeddings
+ byt5_emb: (N, L3, 1472) ByT5 text embeddings
+ timestep: (N,) tensor to indicate a denoising step
+ """
+ x = x.to(mstype.float32)
+
+ if timestep is None:
+ timestep = self._broadcast(self._sample_func(x.shape[0]))
+
+ noise = self._broadcast(ops.randn_like(x))
+ x_t = self.add_noise(x, noise, timestep)
+
+ model_output = self.model(
+ x_t.to(self.model.dtype),
+ timestep,
+ ul2_emb.to(self.model.dtype),
+ metaclip_emb.to(self.model.dtype),
+ byt5_emb.to(self.model.dtype),
+ ).to(mstype.float32)
+ v_t = x - (1 - self.eps) * noise
+
+ # 3.1.2 Eqa (2)
+ loss = self.criteria(model_output, v_t)
+ return loss
+
+ def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor:
+ """
+ x: (N, T, C, H, W) tensor of ground truth
+ noise: (N, T, C, H, W) tensor of white noise
+ timesteps: (N,) tensor of timestamps with range [0, num_timesteps)
+ """
+ timesteps = 1 - timesteps.to(mstype.float32) / self.num_timesteps
+ timesteps = timesteps[:, None, None, None, None]
+
+ # 3.1.2 First Eqa.
+ return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise # TODO: check for zero SNR
+
+
+class RFlowEvalLoss(nn.Cell):
+ def __init__(self, network: RFlowLossWrapper, num_sampling_steps: int = 10):
+ super().__init__()
+ self.network = network
+ self.timesteps = Tensor(
+ np.linspace(0, network.num_timesteps, num_sampling_steps + 2)[1:-1].reshape(-1, 1), dtype=mstype.float32
+ )
+
+ def construct(self, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, **kwargs) -> Tensor:
+ loss = Tensor(0, dtype=mstype.float32)
+ timesteps = mint.tile(self.timesteps, (1, x.shape[0]))
+ for t in timesteps:
+ loss += self.network(x, ul2_emb, metaclip_emb, byt5_emb, t)
+
+ return loss / len(self.timesteps)
diff --git a/examples/moviegen/mg/utils/__init__.py b/examples/moviegen/mg/utils/__init__.py
new file mode 100644
index 0000000000..73fb65477b
--- /dev/null
+++ b/examples/moviegen/mg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .callbacks import *
+from .ema import *
+from .model_utils import *
+from .utils import *
diff --git a/examples/moviegen/mg/utils/callbacks.py b/examples/moviegen/mg/utils/callbacks.py
new file mode 100644
index 0000000000..66f5ff6ca7
--- /dev/null
+++ b/examples/moviegen/mg/utils/callbacks.py
@@ -0,0 +1,197 @@
+import logging
+import os
+import time
+from typing import List, Literal, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from mindspore import Callback, Parameter, ReduceLROnPlateau, RunContext, Tensor
+from mindspore import dtype as mstype
+from mindspore import mint, nn, ops
+from mindspore.communication import GlobalComm, get_group_size
+from mindspore.dataset import BatchDataset, BucketBatchByLengthDataset, GeneratorDataset
+from mindspore.ops import functional as F
+
+from mindone.trainers.ema import EMA
+
+__all__ = ["ValidationCallback", "PerfRecorderCallback", "ReduceLROnPlateauByStep"]
+
+_logger = logging.getLogger(__name__)
+
+
+class ValidationCallback(Callback):
+ """
+ A callback for performing validation during training on a per-step basis.
+
+ Args:
+ network (nn.Cell): The neural network model to be validated.
+ dataset (BatchDataset, BucketBatchByLengthDataset, GeneratorDataset): The dataset to use for validation.
+ alpha_smooth (float, optional): The smoothing factor for the loss. Defaults to 0.01.
+ valid_frequency (int, optional): The frequency of validation in terms of training steps.
+ Defaults to 100.
+ ema (Optional[EMA], optional): An Exponential Moving Average object for the model weights.
+ If provided, it will be used during validation. Defaults to None.
+
+ Example:
+ >>> model = MyModel()
+ >>> val_dataset = MyValidationDataset()
+ >>> val_callback = ValidationCallback(model, val_dataset, valid_frequency=500)
+ >>> model.train(num_epochs, train_dataset, callbacks=[val_callback])
+ """
+
+ def __init__(
+ self,
+ network: nn.Cell,
+ dataset: Union[BatchDataset, BucketBatchByLengthDataset, GeneratorDataset],
+ alpha_smooth: float = 0.01,
+ valid_frequency: int = 100,
+ ema: Optional[EMA] = None,
+ ):
+ super().__init__()
+ self.network = network
+ self.dataset = dataset
+ self.alpha_smooth = alpha_smooth
+ self.valid_frequency = valid_frequency
+ self.ema = ema
+ self.reduce, self.rank_size = None, 1
+ if GlobalComm.INITED:
+ self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM)
+ self.rank_size = get_group_size()
+ self.data = pd.Series(dtype=np.float32)
+
+ def on_train_step_end(self, run_context: RunContext):
+ cb_params = run_context.original_args()
+ cb_params.eval_results = {} # Erase previous validation results
+ cur_step = cb_params.cur_step_num
+
+ if cur_step % self.valid_frequency == 0:
+ if self.ema is not None:
+ self.ema.swap_before_eval()
+ self.network.set_train(False)
+
+ loss = 0
+ for data in self.dataset.create_tuple_iterator(num_epochs=1):
+ loss += self.network(*data)
+ loss = loss / self.dataset.get_dataset_size()
+ if self.reduce is not None:
+ loss = self.reduce(loss)
+ loss = loss.item() / self.rank_size
+
+ self.data = pd.concat([self.data, pd.Series(loss)], ignore_index=True)
+ loss_smoothed = self.data.ewm(alpha=self.alpha_smooth).mean().iloc[-1]
+
+ cb_params.eval_results = {"eval_loss": loss, "eval_loss_smoothed": loss_smoothed}
+ _logger.info(f"Step: {cur_step}, Validation Loss: {loss}.")
+
+ self.network.set_train(True)
+ if self.ema is not None:
+ self.ema.swap_after_eval()
+
+
+class PerfRecorderCallback(Callback):
+ """
+ Improved version of `mindone.trainers.recorder.PerfRecorder` that tracks validation metrics as well.
+ Used here first for testing.
+ """
+
+ def __init__(
+ self,
+ save_dir: str,
+ file_name: str = "result.log",
+ metric_names: List[str] = None,
+ separator: str = "\t",
+ ):
+ super().__init__()
+ self._sep = separator
+ self._metrics = metric_names or []
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+ self._log_file = os.path.join(save_dir, file_name)
+
+ header = separator.join([f"{'step':<7}", f"{'loss':<10}", "train_time(s)"] + self._metrics)
+ with open(self._log_file, "w", encoding="utf-8") as fp:
+ fp.write(header + "\n")
+
+ def on_train_step_begin(self, run_context: RunContext):
+ self._step_time = time.perf_counter()
+
+ def on_train_step_end(self, run_context: RunContext):
+ step_time = time.perf_counter() - self._step_time
+ cb_params = run_context.original_args()
+ cur_step = cb_params.cur_step_num
+ loss = cb_params.net_outputs
+ loss = loss[0].asnumpy() if isinstance(loss, tuple) else np.mean(loss.asnumpy())
+ eval_loss = cb_params.get("eval_results", [])
+ metrics = (self._sep + self._sep.join([f"{eval_loss[m]:.6f}" for m in self._metrics])) if eval_loss else ""
+
+ with open(self._log_file, "a", encoding="utf-8") as fp:
+ fp.write(
+ self._sep.join([f"{cur_step:<7}", f"{loss.item():<10.6f}", f"{step_time:<13.3f}"]) + metrics + "\n"
+ )
+
+
+class ReduceLROnPlateauByStep(ReduceLROnPlateau):
+ """
+ Extends ReduceLROnPlateau to reduce the learning rate at the end of a step and incorporates loss smoothing.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ monitor: str = "eval_loss_smoothed",
+ factor: float = 0.1,
+ patience: int = 10,
+ mode: Literal["auto", "min", "max"] = "auto",
+ min_delta: float = 1e-4,
+ cooldown: int = 0,
+ min_lr: float = 0.0,
+ ):
+ super().__init__(monitor, factor, patience, mode=mode, min_delta=min_delta, cooldown=cooldown, min_lr=min_lr)
+ self.optimizer = optimizer
+ self.min_lr = Tensor(self.min_lr, dtype=mstype.float32)
+
+ def on_train_step_end(self, run_context):
+ """
+ monitors the training process and if no improvement is seen for a 'patience' number
+ of epochs, the learning rate is reduced.
+
+ Copy of the original `on_train_step_end()` with changes to add loss alpha smoothing.
+
+ Args:
+ run_context (RunContext): Context information of the model. For more details,
+ please refer to :class:`mindspore.train.RunContext`.
+ """
+ cb_params = run_context.original_args()
+ cur_step = cb_params.cur_step_num
+ lrs = self.optimizer.learning_rate.learning_rate
+ if not isinstance(lrs, Parameter):
+ raise ValueError("ReduceLROnPlateau does not support dynamic learning rate and group learning rate now.")
+
+ current_monitor_value = cb_params.get("eval_results")
+ if current_monitor_value:
+ current_monitor_value = current_monitor_value[self.monitor]
+
+ if self.cooldown_counter > 0:
+ self.cooldown_counter -= 1
+ self.wait = 0
+
+ if self.is_improvement(current_monitor_value, self.best):
+ self.best = current_monitor_value
+ self.wait = 0
+ elif self.cooldown_counter <= 0:
+ self.wait += 1
+ if self.wait >= self.patience:
+ if lrs[cur_step] > self.min_lr: # FIXME: doesn't hold for future LRs
+ new_lr = lrs * self.factor
+ min_lr = mint.tile(self.min_lr, lrs.shape)
+ new_lr = mint.where(new_lr < min_lr, min_lr, new_lr)
+ F.assign(self.optimizer.learning_rate.learning_rate, new_lr)
+ _logger.info(f"Step {cur_step}: reducing learning rate to {new_lr[cur_step]}.")
+ self.cooldown_counter = self.cooldown
+ self.wait = 0
+
+ def on_train_epoch_end(self, run_context):
+ # Use `on_train_step_end` instead
+ pass
diff --git a/examples/moviegen/mg/utils/ema.py b/examples/moviegen/mg/utils/ema.py
new file mode 100644
index 0000000000..9ff3db69f9
--- /dev/null
+++ b/examples/moviegen/mg/utils/ema.py
@@ -0,0 +1,24 @@
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+
+from mindone.trainers.ema import EMA as EMA_
+
+__all__ = ["EMA"]
+
+_ema_op = C.MultitypeFuncGraph("grad_ema_op")
+
+
+@_ema_op.register("Number", "Tensor", "Tensor")
+def _ema_weights(factor, ema_weight, weight):
+ return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))
+
+
+class EMA(EMA_):
+ def ema_update(self):
+ """Update EMA parameters."""
+ self.updates += 1
+ # update trainable parameters
+ success = self.hyper_map(F.partial(_ema_op, self.ema_decay), self.ema_weight, self.net_weight)
+ self.updates = F.depend(self.updates, success)
+
+ return self.updates
diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py
new file mode 100644
index 0000000000..c7396f49f8
--- /dev/null
+++ b/examples/moviegen/mg/utils/model_utils.py
@@ -0,0 +1,110 @@
+import logging
+from typing import Dict, Literal, Optional, Tuple, Union
+
+from jsonargparse.typing import Path_fr
+from mg.models import LlamaModel, llama3_1B, llama3_5B, llama3_30B
+
+import mindspore as ms
+from mindspore import _no_grad, jit_class, nn
+
+from mindone.trainers.train_step import TrainOneStepWrapper
+from mindone.utils.params import load_param_into_net_with_filter
+
+__all__ = ["MODEL_DTYPE", "no_grad", "init_model", "resume_train_net"]
+
+logger = logging.getLogger(__name__)
+
+MODEL_SPEC = {"llama-1B": llama3_1B, "llama-5B": llama3_5B, "llama-30B": llama3_30B}
+
+MODEL_DTYPE = {
+ "fp32": ms.float32,
+ "fp16": ms.float16,
+ "bf16": ms.bfloat16,
+}
+
+
+def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> None:
+ if isinstance(ckpt, str):
+ logger.info(f"Loading {ckpt} params into network...")
+ param_dict = ms.load_checkpoint(ckpt)
+ param_dict = {k.replace("network.model.", ""): v for k, v in param_dict.items()}
+ else:
+ param_dict = ckpt
+
+ param_not_load, ckpt_not_load = ms.load_param_into_net(model, param_dict)
+ if param_not_load or ckpt_not_load:
+ logger.warning(
+ f"Exist ckpt params not loaded: {ckpt_not_load} (total: {len(ckpt_not_load)}),\n"
+ f"or net params not loaded: {param_not_load} (total: {len(param_not_load)})"
+ )
+
+
+@jit_class
+class no_grad(_no_grad):
+ """
+ A context manager that suppresses gradient memory allocation in PyNative mode.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE
+
+ def __enter__(self):
+ if self._pynative:
+ super().__enter__()
+
+ def __exit__(self, *args):
+ if self._pynative:
+ super().__exit__(*args)
+
+
+def init_model(
+ name: Literal["llama-1B", "llama-5B", "llama-30B"],
+ in_channels: int = 16,
+ pretrained_model_path: Optional[Path_fr] = None,
+ resume: bool = False,
+ enable_flash_attention: bool = True,
+ recompute_every_nth_block: Optional[int] = None,
+ dtype: Literal["fp32", "fp16", "bf16"] = "fp32",
+) -> LlamaModel:
+ attn_implementation = "flash_attention" if enable_flash_attention else "eager"
+ model = MODEL_SPEC[name](
+ in_channels=in_channels,
+ attn_implementation=attn_implementation,
+ recompute_every_nth_block=recompute_every_nth_block,
+ dtype=MODEL_DTYPE[dtype],
+ )
+
+ if resume:
+ logger.info("Resume training checkpoint provided, skipping weight loading.")
+ elif pretrained_model_path:
+ load_ckpt_params(model, pretrained_model_path.absolute)
+ else:
+ logger.info(f"Initialize {name} model randomly.")
+ return model
+
+
+def resume_train_net(
+ train_net: TrainOneStepWrapper, resume_ckpt: Optional[Path_fr] = None
+) -> Tuple[Union[int, None], Union[int, None]]:
+ if resume_ckpt is None:
+ return None, None
+
+ state_dict = ms.load_checkpoint(resume_ckpt)
+ if "epoch_num" not in state_dict or "cur_step" not in state_dict or "loss_scale" not in state_dict:
+ raise ValueError("Resume training checkpoint is invalid. Please check the checkpoint file.")
+
+ start_epoch = state_dict.pop("epoch_num").item()
+ global_step = state_dict.pop("cur_step").item()
+ logger.info(f"Resuming training of network from {resume_ckpt} at global step {global_step}")
+
+ # FIXME: `EvalSaveCallback` renames `scale_sense` to `loss_scale` when saving the resume checkpoint
+ train_net.scale_sense = ms.Parameter(state_dict.pop("loss_scale"), name="scale_sense")
+ param_not_load, ckpt_not_load = load_param_into_net_with_filter(train_net, state_dict, filter=state_dict.keys())
+ if param_not_load or ckpt_not_load:
+ logger.warning(
+ f"Exist ckpt params not loaded: {ckpt_not_load} (total: {len(ckpt_not_load)}),\n"
+ f"or net params not loaded: {param_not_load} (total: {len(param_not_load)})"
+ )
+
+ return start_epoch, global_step
diff --git a/examples/moviegen/mg/utils/utils.py b/examples/moviegen/mg/utils/utils.py
new file mode 100644
index 0000000000..93682df59e
--- /dev/null
+++ b/examples/moviegen/mg/utils/utils.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+from mindspore import Tensor
+from mindspore import dtype as mstype
+
+__all__ = ["to_numpy"]
+
+
+def to_numpy(x: Tensor) -> np.ndarray:
+ if x.dtype == mstype.bfloat16:
+ x = x.astype(mstype.float32)
+ return x.asnumpy()
diff --git a/examples/moviegen/requirements.txt b/examples/moviegen/requirements.txt
new file mode 100644
index 0000000000..3dc0560cd1
--- /dev/null
+++ b/examples/moviegen/requirements.txt
@@ -0,0 +1 @@
+jsonargparse[signatures,omegaconf,urls]>=4.33.0
diff --git a/examples/moviegen/scripts/args_train_tae.py b/examples/moviegen/scripts/args_train_tae.py
index 8b1634613d..5d8e737d51 100644
--- a/examples/moviegen/scripts/args_train_tae.py
+++ b/examples/moviegen/scripts/args_train_tae.py
@@ -1,89 +1,55 @@
-import argparse
import logging
import os
import sys
-import yaml
+from jsonargparse import ActionConfigFile, ArgumentParser
+# TODO: remove in future when mindone is ready for install
__dir__ = os.path.dirname(os.path.abspath(__file__))
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
-sys.path.insert(0, mindone_lib_path)
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
-from mg.utils.parser import _check_cfgs_in_parser, str2bool
+from mg.dataset.tae_dataset import BatchTransform, VideoDataset
+from mg.models.tae import TemporalAutoencoder
+from mindone.data import create_dataloader
+from mindone.utils import init_train_env
from mindone.utils.misc import to_abspath
logger = logging.getLogger()
-def parse_train_args(parser):
+def parse_train_args():
+ parser = ArgumentParser(description="Temporal Autoencoder training script.")
parser.add_argument(
- "--config",
"-c",
- default="",
- type=str,
- help="path to load a config yaml file that describes the training recipes which will override the default arguments",
+ action=ActionConfigFile,
+ help="Path to load a config yaml file that describes the setting which will override the default arguments.",
)
- # the following args's defualt value will be overrided if specified in config yaml
-
- # data
- parser.add_argument("--dataset_name", default="", type=str, help="dataset name")
- parser.add_argument(
- "--csv_path",
- default="",
- type=str,
- help="path to csv annotation file. columns: video, caption. \
- video indicates the relative path of video file in video_folder. caption - the text caption for video",
+ parser.add_function_arguments(
+ init_train_env, skip={"ascend_config", "num_workers", "json_data_path", "enable_modelarts"}
)
- parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file")
- parser.add_argument("--random_crop", default=False, type=str2bool, help="randonly crop the image")
- parser.add_argument("--flip", default=False, type=str2bool, help="flip the image")
-
- parser.add_argument(
- "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file"
+ parser.add_class_arguments(TemporalAutoencoder, instantiate=False)
+ parser.add_class_arguments(VideoDataset, skip={"output_columns"}, instantiate=False)
+ parser.add_class_arguments(BatchTransform, instantiate=False)
+ parser.add_function_arguments(
+ create_dataloader,
+ skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
)
- parser.add_argument("--video_folder", default="", type=str, help="root dir for the video data")
parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results")
parser.add_argument(
"--add_datetime", default=True, type=str, help="If True, add datetime subfolder under output_path"
)
# model
- parser.add_argument("--model_type", default="OpenSora-VAE-v1.2", type=str, help="VAE model type")
- parser.add_argument("--freeze_vae_2d", default=True, type=str2bool, help="Freeze 2d vae")
- parser.add_argument(
- "--use_discriminator", default=False, type=str2bool, help="Use discriminator for adversarial training."
- )
- parser.add_argument(
- "--pretrained_model_path",
- default="",
- type=str,
- help="Specify the pretrained model path",
- )
parser.add_argument("--perceptual_loss_weight", default=0.1, type=float, help="perceptual (lpips) loss weight")
parser.add_argument("--kl_loss_weight", default=1.0e-6, type=float, help="KL loss weight")
parser.add_argument(
"--use_outlier_penalty_loss",
default=False,
- type=str2bool,
+ type=bool,
help="use outlier penalty loss",
)
- # data
- parser.add_argument("--mixed_strategy", type=str, default=None, help="video and image mixed strategy")
- parser.add_argument(
- "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training"
- )
-
- # ms
- parser.add_argument("--debug", type=str2bool, default=False, help="Execute inference in debug mode.")
- parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU")
- parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b")
- parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
- parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel")
- parser.add_argument(
- "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim"
- )
- parser.add_argument("--jit_level", default="O0", type=str, help="O0 kbk, O1 dvm, O2 ge")
-
# training hyper-params
parser.add_argument(
"--resume",
@@ -110,32 +76,18 @@ def parse_train_args(parser):
If None, filter list is [layernorm, bias], Default: None",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.")
- parser.add_argument("--seed", default=3407, type=int, help="data path")
parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps")
- parser.add_argument("--batch_size", default=10, type=int, help="batch size")
- parser.add_argument(
- "--micro_batch_size",
- type=int,
- default=4,
- help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation",
- )
- parser.add_argument(
- "--micro_frame_size",
- type=int,
- default=17,
- help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation. Used by temporal vae",
- )
parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.")
parser.add_argument(
- "--scale_lr", default=False, type=str2bool, help="scale base-lr by ngpu * batch_size * n_accumulate"
+ "--scale_lr", default=False, type=bool, help="scale base-lr by ngpu * batch_size * n_accumulate"
)
parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.")
parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.")
- parser.add_argument("--pre_patchify", default=False, type=str2bool, help="Training with patchified latent.")
+ parser.add_argument("--pre_patchify", default=False, type=bool, help="Training with patchified latent.")
# dataloader params
- parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode")
+ parser.add_argument("--dataset_sink_mode", default=False, type=bool, help="sink mode")
parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. If -1, sink size = dataset size.")
parser.add_argument(
"--epochs",
@@ -150,94 +102,29 @@ def parse_train_args(parser):
parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor")
parser.add_argument("--scale_window", default=2000, type=float, help="scale window")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
- # parser.add_argument("--cond_stage_trainable", default=False, type=str2bool, help="whether text encoder is trainable")
- parser.add_argument("--use_ema", default=False, type=str2bool, help="whether use EMA")
+ # parser.add_argument("--cond_stage_trainable", default=False, type=bool, help="whether text encoder is trainable")
+ parser.add_argument("--use_ema", default=False, type=bool, help="whether use EMA")
parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay ratio")
- parser.add_argument("--clip_grad", default=False, type=str2bool, help="whether apply gradient clipping")
- parser.add_argument(
- "--use_recompute",
- default=False,
- type=str2bool,
- help="whether use recompute.",
- )
- parser.add_argument(
- "--num_recompute_blocks",
- default=None,
- type=int,
- help="If None, all stdit blocks will be applied with recompute (gradient checkpointing). If int, the first N blocks will be applied with recompute",
- )
- parser.add_argument(
- "--dtype",
- default="fp16",
- type=str,
- choices=["bf16", "fp16", "fp32"],
- help="what computation data type to use for latte. Default is `fp16`, which corresponds to ms.float16",
- )
+ parser.add_argument("--clip_grad", default=False, type=bool, help="whether apply gradient clipping")
parser.add_argument(
"--vae_keep_gn_fp32",
default=True,
- type=str2bool,
+ type=bool,
help="whether keep GroupNorm in fp32.",
)
parser.add_argument(
"--vae_keep_updown_fp32",
default=True,
- type=str2bool,
+ type=bool,
help="whether keep spatial/temporal upsample and downsample in fp32.",
)
- parser.add_argument(
- "--global_bf16",
- default=False,
- type=str2bool,
- help="Experimental. If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN",
- )
- parser.add_argument(
- "--vae_param_dtype",
- default="fp32",
- type=str,
- choices=["bf16", "fp16", "fp32"],
- help="what param data type to use for vae. Default is `fp32`, which corresponds to ms.float32",
- )
- parser.add_argument(
- "--amp_level",
- default="O2",
- type=str,
- help="mindspore amp level, O1: most fp32, only layers in whitelist compute in fp16 (dense, conv, etc); \
- O2: most fp16, only layers in blacklist compute in fp32 (batch norm etc)",
- )
- parser.add_argument("--vae_amp_level", default="O2", type=str, help="O2 or O3")
- parser.add_argument(
- "--vae_checkpoint",
- type=str,
- default="models/sd-vae-ft-ema.ckpt",
- help="VAE checkpoint file path which is used to load vae weight.",
- )
- parser.add_argument(
- "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model."
- )
- parser.add_argument(
- "--image_size", default=256, type=int, nargs="+", help="image size for resizing the input image"
- )
- parser.add_argument("--crop_size", default=256, type=int, help="crop size after resize")
- parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model")
- parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride")
- parser.add_argument("--mask_ratios", type=dict, help="Masking ratios")
- parser.add_argument("--bucket_config", type=dict, help="Multi-resolution bucketing configuration")
- parser.add_argument("--num_parallel_workers", default=12, type=int, help="num workers for data loading")
- parser.add_argument(
- "--data_multiprocessing",
- default=False,
- type=str2bool,
- help="If True, use multiprocessing for data processing. Default: multithreading.",
- )
- parser.add_argument("--max_rowsize", default=64, type=int, help="max rowsize for data loading")
parser.add_argument(
"--enable_flash_attention",
default=None,
- type=str2bool,
+ type=bool,
help="whether to enable flash attention.",
)
- parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update")
+ parser.add_argument("--drop_overflow_update", default=True, type=bool, help="drop overflow update")
parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static")
parser.add_argument(
"--max_grad_norm",
@@ -256,10 +143,10 @@ def parse_train_args(parser):
parser.add_argument(
"--step_mode",
default=False,
- type=str2bool,
+ type=bool,
help="whether save ckpt by steps. If False, save ckpt by epochs.",
)
- parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not")
+ parser.add_argument("--profile", default=False, type=bool, help="Profile or not")
parser.add_argument(
"--log_level",
type=str,
@@ -276,19 +163,12 @@ def parse_train_args(parser):
def parse_args():
- parser = argparse.ArgumentParser()
- parser = parse_train_args(parser)
+ parser = parse_train_args()
+ args = parser.parse_args()
__dir__ = os.path.dirname(os.path.abspath(__file__))
abs_path = os.path.abspath(os.path.join(__dir__, ".."))
- default_args = parser.parse_args()
- if default_args.config:
- default_args.config = to_abspath(abs_path, default_args.config)
- with open(default_args.config, "r") as f:
- cfg = yaml.safe_load(f)
- _check_cfgs_in_parser(cfg, parser)
- parser.set_defaults(**cfg)
- args = parser.parse_args()
+
# convert to absolute path, necessary for modelarts
args.csv_path = to_abspath(abs_path, args.csv_path)
args.video_folder = to_abspath(abs_path, args.video_folder)
diff --git a/examples/moviegen/scripts/inference_vae.py b/examples/moviegen/scripts/eval_tae.py
similarity index 51%
rename from examples/moviegen/scripts/inference_vae.py
rename to examples/moviegen/scripts/eval_tae.py
index c2323840f0..61cc7c50fd 100644
--- a/examples/moviegen/scripts/inference_vae.py
+++ b/examples/moviegen/scripts/eval_tae.py
@@ -1,8 +1,6 @@
-# flake8: noqa
"""
Infer and evaluate autoencoders
"""
-import argparse
import logging
import os
import sys
@@ -10,34 +8,29 @@
import imageio
import numpy as np
-
-from mindspore import nn, ops
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../"))
-sys.path.insert(0, mindone_dir)
-
-
-from omegaconf import OmegaConf
+from jsonargparse import ArgumentParser
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as calc_psnr
from skimage.metrics import structural_similarity as calc_ssim
from tqdm import tqdm
import mindspore as ms
+from mindspore import amp, nn, ops
+# TODO: remove in future when mindone is ready for install
__dir__ = os.path.dirname(os.path.abspath(__file__))
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
-sys.path.insert(0, mindone_lib_path)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
-from mg.datasets.tae_dataset import create_dataloader
-from mg.models.tae.lpips import LPIPS
-from mg.models.tae.tae import TemporalAutoencoder
+from mg.dataset.tae_dataset import VideoDataset
+from mg.models.tae import TemporalAutoencoder
+
+from mindone.data import create_dataloader
+from mindone.utils import init_train_env, set_logger
+
+# from mg.models.tae.lpips import LPIPS
-from mindone.utils.amp import auto_mixed_precision
-from mindone.utils.config import instantiate_from_config, str2bool
-from mindone.utils.logger import set_logger
logger = logging.getLogger(__name__)
@@ -95,99 +88,61 @@ def rearrange_out(x, t):
def main(args):
- ascend_config = {"precision_mode": "must_keep_origin_dtype"}
- ms.set_context(mode=args.mode, ascend_config=ascend_config)
- ms.set_context(jit_config={"jit_level": args.jit_level})
- set_logger(name="", output_dir=args.output_path, rank=0)
+ # set env
+ # TODO: rename as train and infer are identical?
+ _, rank_id, device_num = init_train_env(mode=args.mode, ascend_config={"precision_mode": "must_keep_origin_dtype"})
+ set_logger(name="", output_dir=args.output_path, rank=rank_id)
# build model
- model = TemporalAutoencoder(
- pretrained=args.ckpt_path,
- use_tile=args.enable_tile,
- )
-
- model.set_train(False)
- logger.info(f"Loaded checkpoint from {args.ckpt_path}")
-
- if args.eval_loss:
- lpips_loss_fn = LPIPS()
-
+ model = TemporalAutoencoder(pretrained=args.pretrained, use_tile=args.use_tile).set_train(False)
if args.dtype != "fp32":
- amp_level = "O2"
dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype]
# FIXME: due to AvgPool and ops.interpolate doesn't support bf16, we add them to fp32 cells
custom_fp32_cells = [nn.GroupNorm, nn.AvgPool2d, nn.Upsample]
- model = auto_mixed_precision(model, amp_level, dtype, custom_fp32_cells)
+ model = amp.custom_mixed_precision(model, black_list=amp.get_black_list() + custom_fp32_cells, dtype=dtype)
logger.info(f"Set mixed precision to O2 with dtype={args.dtype}")
- else:
- amp_level = "O0"
- # build dataset
- if isinstance(args.image_size, int):
- image_size = args.image_size
- else:
- if len(args.image_size) == 2:
- assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported"
- image_size = args.image_size[0]
+ # if args.eval_loss:
+ # lpips_loss_fn = LPIPS()
- ds_config = dict(
+ # build dataset
+ dataset = VideoDataset(
csv_path=args.csv_path,
- data_folder=args.video_folder,
- size=image_size,
- crop_size=image_size,
- sample_n_frames=args.num_frames,
- sample_stride=args.frame_stride,
+ folder=args.folder,
+ size=args.image_size,
+ crop_size=args.image_size,
+ sample_n_frames=args.sample_n_frames,
+ sample_stride=args.sample_stride,
video_column=args.video_column,
random_crop=False,
flip=False,
+ output_columns=["video"],
)
dataset = create_dataloader(
- ds_config,
+ dataset,
args.batch_size,
- mixed_strategy=None,
- mixed_image_ratio=0.0,
- num_parallel_workers=8,
+ num_workers=8,
max_rowsize=256,
shuffle=False,
- device_num=1,
- rank_id=0,
+ device_num=device_num,
+ rank_id=rank_id,
drop_remainder=False,
)
num_batches = dataset.get_dataset_size()
- ds_iter = dataset.create_dict_iterator(1)
+ ds_iter = dataset.create_dict_iterator(num_epochs=1)
- if args.dynamic_shape:
- videos = ms.Tensor(shape=[None, 3, None, 256, 256], dtype=ms.float32)
- model.set_inputs(videos)
-
- logger.info("Inferene begins")
- mean_infer_time = 0
- mean_psnr = 0
- mean_ssim = 0
- mean_lpips = 0
- mean_recon = 0
- num_samples = 0
+ mean_infer_time, mean_psnr, mean_ssim, mean_lpips, mean_recon, num_samples = (0,) * 6
for step, data in tqdm(enumerate(ds_iter)):
x = data["video"]
- start_time = time.time()
-
- # debug
- # if args.dynamic_shape:
- # if step % 2 == 0:
- # x = x[:, :, : x.shape[2]//2]
- # print('x shape: ', x.shape)
+ start_time = time.perf_counter()
if args.encode_only:
- z = model.encode(x)
+ z, posterior_mean, posterior_logvar = model.encode(x)
else:
- # recons = model.decode(z)
recons, z, posterior_mean, posterior_logvar = model(x)
- # adapt to bf16
- recons = recons.to(ms.float32)
-
- infer_time = time.time() - start_time
+ infer_time = time.perf_counter() - start_time
mean_infer_time += infer_time
logger.info(f"Infer time: {infer_time}")
@@ -226,9 +181,7 @@ def main(args):
logger.info(f"mean recon loss: {mean_recon/num_batches:.4f}")
if args.save_vis:
- save_fn = os.path.join(
- args.output_path, "{}-{}".format(os.path.basename(args.video_folder), f"step{step:03d}")
- )
+ save_fn = os.path.join(args.output_path, f"{os.path.basename(args.video_folder)}-{f'step{step:03d}'}")
if not is_video:
visualize_image(recons_rgb, x_rgb, save_fn=save_fn)
else:
@@ -254,91 +207,40 @@ def main(args):
# logger.info(f"mean lpips loss: {mean_lpips:.4f}")
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--model_config",
- default="configs/autoencoder_kl_f8.yaml",
- type=str,
- help="model architecture config",
- )
- parser.add_argument(
- "--ckpt_path", default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt", type=str, help="checkpoint path"
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_function_arguments(
+ init_train_env, skip={"ascend_config", "num_workers", "json_data_path", "enable_modelarts"}
)
- parser.add_argument(
- "--csv_path",
- default=None,
- type=str,
- help="path to csv annotation file. If None, will get videos from the folder of `data_path`",
+ parser.add_class_arguments(TemporalAutoencoder, instantiate=False)
+ parser.add_class_arguments(VideoDataset, skip={"output_columns"}, instantiate=False)
+ parser.add_function_arguments(
+ create_dataloader,
+ skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
)
- parser.add_argument("--video_folder", default=None, type=str, help="folder of videos")
parser.add_argument(
"--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results"
)
- parser.add_argument("--num_frames", default=17, type=int, help="num frames")
- parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride")
parser.add_argument(
"--expand_dim_t",
default=False,
- type=str2bool,
+ type=bool,
help="expand temporal axis for image data, used for vae 3d inference with image data",
)
- parser.add_argument("--image_size", default=256, type=int, help="image rescale size")
- # parser.add_argument("--crop_size", default=256, type=int, help="image crop size")
-
- parser.add_argument("--batch_size", default=1, type=int, help="batch size")
- parser.add_argument("--num_parallel_workers", default=8, type=int, help="num workers for data loading")
parser.add_argument(
"--eval_loss",
default=False,
- type=str2bool,
+ type=bool,
help="whether measure loss including reconstruction, kl, perceptual loss",
)
- parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images")
- parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae")
- parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution")
- parser.add_argument(
- "--enable_tile", default=False, type=str2bool, help="enable temporal tiling with linear blending for decoder"
- )
- parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file")
- parser.add_argument(
- "--mixed_strategy",
- type=str,
- default=None,
- choices=[None, "mixed_video_image", "image_only"],
- help="video and image mixed strategy.",
- )
- parser.add_argument(
- "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training"
- )
+ parser.add_argument("--save_vis", default=True, type=bool, help="whether save reconstructed images")
+ parser.add_argument("--use_temporal_vae", default=True, type=bool, help="if False, just use spatial vae")
+ parser.add_argument("--encode_only", default=False, type=bool, help="only encode to save z or distribution")
parser.add_argument(
"--save_z_dist",
default=False,
- type=str2bool,
+ type=bool,
help="If True, save z distribution, mean and logvar. Otherwise, save z after sampling.",
)
- parser.add_argument(
- "--dynamic_shape", default=False, type=str2bool, help="whether input shape to the network is dynamic"
- )
-
- # ms related
- parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
- parser.add_argument("--jit_level", default="O0", type=str, help="O0 kbk, O1 dvm, O2 ge")
- parser.add_argument(
- "--dtype",
- default="fp32",
- type=str,
- choices=["fp32", "fp16", "bf16"],
- help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \
- if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.",
- )
- parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU")
-
args = parser.parse_args()
-
- return args
-
-
-if __name__ == "__main__":
- args = parse_args()
main(args)
diff --git a/examples/moviegen/scripts/gradio_demo.py b/examples/moviegen/scripts/gradio_demo.py
new file mode 100644
index 0000000000..80f304fcfe
--- /dev/null
+++ b/examples/moviegen/scripts/gradio_demo.py
@@ -0,0 +1,229 @@
+import datetime
+import glob
+import logging
+import os
+import sys
+import time
+from typing import List, Tuple
+
+import gradio as gr
+import numpy as np
+from jsonargparse import ActionConfigFile, ArgumentParser
+from jsonargparse.typing import path_type
+
+import mindspore as ms
+from mindspore import amp, nn
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
+
+from mg.models.tae import TemporalAutoencoder
+from mg.pipelines import InferPipeline
+from mg.utils import init_model, to_numpy
+
+from mindone.utils import init_train_env, set_logger
+from mindone.visualize import save_videos
+
+logger = logging.getLogger(__name__)
+
+
+def prepare_captions(
+ ul2_dir: str, metaclip_dir: str, byt5_dir: str, rank_id: int = 0, device_num: int = 1
+) -> Tuple[List[str], List[str], List[str]]:
+ """Prepare caption embeddings from specified directories"""
+ ul2_emb = sorted(glob.glob(os.path.join(ul2_dir, "*.npz")))
+ metaclip_emb = sorted(glob.glob(os.path.join(metaclip_dir, "*.npz")))
+ byt5_emb = sorted(glob.glob(os.path.join(byt5_dir, "*.npz")))
+
+ if len(ul2_emb) != len(byt5_emb):
+ raise ValueError(
+ f"ul2_dir ({len(ul2_emb)}), metaclip_dir ({len(metaclip_emb)}), "
+ f" and byt5_dir ({len(byt5_emb)}) must contain the same number of files"
+ )
+
+ ul2_emb = ul2_emb[rank_id::device_num]
+ logger.info(f"Number of captions for rank {rank_id}: {len(ul2_emb)}")
+ return ul2_emb, metaclip_emb[rank_id::device_num], byt5_emb[rank_id::device_num]
+
+
+def load_embeddings(selected_prompts: List[str], args) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
+ """Load embeddings for selected prompts matching original implementation"""
+ # Get full paths for selected prompts
+ # print(selected_prompts)
+ ul2_files = os.path.join(args.text_emb.ul2_dir, f"{selected_prompts}.npz")
+ byt5_files = os.path.join(args.text_emb.byt5_dir, f"{selected_prompts}.npz")
+
+ # Load embeddings in batch
+ ul2_emb = ms.Tensor(np.load(ul2_files)["text_emb"], dtype=ms.float32)
+ byt5_emb = ms.Tensor(np.load(byt5_files)["text_emb"], dtype=ms.float32)
+ ul2_emb = ul2_emb.unsqueeze(0)
+ byt5_emb = byt5_emb.unsqueeze(0)
+
+ # Create placeholder metaclip embedding matching batch size
+ metaclip_emb = ms.Tensor(np.ones((ul2_emb.shape[0], 300, 1280)), dtype=ms.float32)
+ return ul2_emb, metaclip_emb, byt5_emb
+
+
+def init_models(args):
+ """Initialize MovieGen models with specified configurations"""
+ # Initialize TAE
+ logger.info("Initializing TAE...")
+ tae = TemporalAutoencoder(**args.tae).set_train(False)
+ if tae.dtype != ms.float32:
+ amp.custom_mixed_precision(
+ tae, black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=tae.dtype
+ )
+
+ # Initialize Transformer model
+ logger.info("Initializing Transformer model...")
+ model = init_model(in_channels=tae.out_channels, **args.model).set_train(False)
+
+ return model, tae
+
+
+def create_pipeline(model, tae, args):
+ """Create MovieGen inference pipeline"""
+ img_h, img_w = args.image_size if isinstance(args.image_size, list) else (args.image_size, args.image_size)
+ latent_size = tae.get_latent_size((args.num_frames, img_h, img_w))
+
+ return InferPipeline(
+ model,
+ tae,
+ latent_size,
+ guidance_scale=args.guidance_scale,
+ num_sampling_steps=args.num_sampling_steps,
+ sample_method=args.sample_method,
+ micro_batch_size=args.micro_batch_size,
+ )
+
+
+def generate_video(selected_prompts: List[str], args, pipeline, progress=gr.Progress()) -> List[str]:
+ """Generate videos for selected prompts"""
+ progress(0.1, "Loading embeddings...")
+ ul2_emb, metaclip_emb, byt5_emb = load_embeddings(selected_prompts, args)
+
+ progress(0.2, "Generating videos...")
+ start_time = time.perf_counter()
+ sample, latent = pipeline(
+ ul2_emb=ul2_emb,
+ metaclip_emb=metaclip_emb,
+ byt5_emb=byt5_emb,
+ num_frames=args.num_frames,
+ )
+ # import pdb
+ # pdb.set_trace()
+ generation_time = time.perf_counter() - start_time
+
+ progress(0.8, "Saving videos...")
+ save_dir = os.path.join(args.output_path, "gradio_samples")
+ if args.append_timestamp:
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ save_dir = os.path.join(save_dir, time_str)
+ os.makedirs(save_dir, exist_ok=True)
+
+ output_files = []
+ # for i, prompt in enumerate(selected_prompts):
+ output_file = os.path.join(save_dir, f"{selected_prompts}.{args.save_format}")
+ save_videos(to_numpy(sample[0]), output_file, fps=args.fps)
+ output_files.append(output_file)
+
+ logger.info(
+ f"Videos generated in {generation_time: .2f}s "
+ f"({args.num_sampling_steps * len(selected_prompts) / generation_time: .2f} steps/s)"
+ )
+
+ return output_files
+
+
+def create_demo(args):
+ """Create and configure Gradio interface"""
+ # Initialize models and pipeline
+ model, tae = init_models(args)
+ pipeline = create_pipeline(model, tae, args)
+
+ # Get available prompts
+ ul2_emb, _, _ = prepare_captions(**args.text_emb)
+ prompts = [os.path.basename(p)[:-4] for p in ul2_emb]
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+ gr.Markdown("# MovieGen Video Generation Demo")
+ gr.Markdown(f"Model: {args.model.name}")
+
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Dropdown(
+ choices=prompts,
+ label="Select Pre-computed Prompt",
+ info="Choose from available pre-computed prompts",
+ )
+ generate_btn = gr.Button("Generate Video", variant="primary")
+
+ with gr.Column():
+ video_output = gr.Video(label="Generated Video")
+ info_box = gr.Textbox(label="Generation Info", interactive=False)
+
+ def generate_and_log(prompt_name):
+ print("Prompt name ", prompt_name)
+ output_file = generate_video(prompt_name, args, pipeline)
+ info = f"Successfully generated video for prompt: {prompt_name}"
+ return output_file[0], info
+
+ generate_btn.click(
+ fn=generate_and_log,
+ inputs=[prompt],
+ outputs=[video_output, info_box],
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description="MovieGen Gradio demo")
+ parser.add_argument(
+ "-c",
+ "--config",
+ action=ActionConfigFile,
+ help="Path to MovieGen config file",
+ )
+
+ # Add all necessary arguments
+ parser.add_function_arguments(init_train_env, "env")
+ parser.add_function_arguments(init_model, "model", skip={"in_channels"})
+
+ # TAE parameters
+ tae_group = parser.add_argument_group("TAE parameters")
+ tae_group.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False)
+
+ # Inference parameters
+ infer_group = parser.add_argument_group("Inference parameters")
+ infer_group.add_class_arguments(InferPipeline, skip={"model", "tae", "latent_size"}, instantiate=False)
+ infer_group.add_argument("--image_size", type=int, nargs="+", default=[256, 455])
+ infer_group.add_argument("--num_frames", type=int, default=32)
+ infer_group.add_argument("--fps", type=int, default=16)
+ infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num"})
+ infer_group.add_argument("--batch_size", type=int, default=2)
+
+ # Save options
+ save_group = parser.add_argument_group("Saving options")
+ save_group.add_argument("--save_format", default="mp4", choices=["gif", "mp4", "png"])
+ save_group.add_argument("--output_path", default="output/", type=path_type("dcc"))
+ save_group.add_argument("--append_timestamp", type=bool, default=True)
+ save_group.add_argument(
+ "--save_latent",
+ type=bool,
+ default=False,
+ help="Save denoised video latent. If True, the denoised latents will be saved in $output_path/denoised_latents",
+ )
+ args = parser.parse_args()
+
+ # Set up logging
+ os.makedirs(os.path.join(args.output_path, "logs"), exist_ok=True)
+ set_logger(name="", output_dir=os.path.join(args.output_path, "logs"))
+
+ # Create and launch demo
+ demo = create_demo(args)
+ demo.launch()
diff --git a/examples/moviegen/scripts/inference.py b/examples/moviegen/scripts/inference.py
new file mode 100644
index 0000000000..ccf7d6a0db
--- /dev/null
+++ b/examples/moviegen/scripts/inference.py
@@ -0,0 +1,231 @@
+import datetime
+import glob
+import logging
+import os
+import sys
+import time
+from typing import List, Tuple
+
+import numpy as np
+from jsonargparse import ActionConfigFile, ArgumentParser
+from jsonargparse.typing import path_type
+from mg.acceleration import set_sequence_parallel_group
+
+import mindspore as ms
+from mindspore import amp, nn
+from mindspore.communication import GlobalComm
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
+
+from mg.models.tae import TemporalAutoencoder
+from mg.pipelines import InferPipeline
+from mg.utils import init_model, to_numpy
+
+from mindone.utils import init_train_env, set_logger
+from mindone.visualize import save_videos
+
+logger = logging.getLogger(__name__)
+
+Path_dr = path_type("dr", docstring="path to a directory that exists and is readable")
+
+
+def prepare_captions(
+ ul2_dir: Path_dr, metaclip_dir: Path_dr, byt5_dir: Path_dr, rank_id: int, device_num: int, enable_sp: bool = False
+) -> Tuple[List[str], List[str], List[str]]:
+ ul2_emb = sorted(glob.glob(os.path.join(ul2_dir, "*.npz")))
+ metaclip_emb = sorted(glob.glob(os.path.join(metaclip_dir, "*.npz")))
+ byt5_emb = sorted(glob.glob(os.path.join(byt5_dir, "*.npz")))
+ if len(ul2_emb) != len(metaclip_emb) or len(ul2_emb) != len(byt5_emb):
+ raise ValueError(
+ f"ul2_dir ({len(ul2_emb)}), metaclip_dir ({len(metaclip_emb)}),"
+ f" and byt5_dir ({len(byt5_emb)}) must contain the same number of files"
+ )
+ if enable_sp:
+ logger.info(f"Sequence parallel is enabled, loading all captions to all ranks: {len(ul2_emb)} captions")
+ return ul2_emb, metaclip_emb, byt5_emb
+ else:
+ ul2_emb = ul2_emb[rank_id::device_num]
+ logger.info(f"Number of captions for rank {rank_id}: {len(ul2_emb)}")
+ return ul2_emb, metaclip_emb[rank_id::device_num], byt5_emb[rank_id::device_num]
+
+
+def main(args):
+ # TODO: CFG error
+ save_dir = os.path.abspath(args.output_path)
+ if args.append_timestamp:
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ save_dir = os.path.join(save_dir, time_str)
+ os.makedirs(save_dir, exist_ok=True)
+ set_logger(name="", output_dir=save_dir)
+
+ latent_dir = os.path.join(save_dir, "denoised_latents")
+ if args.save_latent:
+ os.makedirs(latent_dir, exist_ok=True)
+
+ # 1. init env
+ _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical?
+
+ if args.enable_sequence_parallel:
+ set_sequence_parallel_group(GlobalComm.WORLD_COMM_GROUP)
+
+ # 1.1 read caption embeddings
+ ul2_emb, metaclip_emb, byt5_emb = prepare_captions(
+ **args.text_emb, rank_id=rank_id, device_num=device_num, enable_sp=args.enable_sequence_parallel
+ )
+
+ # 2. model initiate and weight loading
+ # 2.1 tae
+ if args.tae is not None:
+ logger.info("Initializing TAE...")
+ tae = TemporalAutoencoder(**args.tae.init_args).set_train(False)
+ if tae.dtype != ms.float32:
+ # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative
+ amp.custom_mixed_precision(
+ tae, black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=tae.dtype
+ )
+ if args.model.in_channels != tae.out_channels:
+ logger.warning(
+ f"The number of model input channels ({args.model.in_channels}) doesn't match the number of TAE output"
+ f" channels ({tae.out_channels}). Setting it to {tae.out_channels}."
+ )
+ args.model.in_channels = tae.out_channels
+ else:
+ logger.info("Skipping TAE initialization.")
+ tae = None
+
+ img_h, img_w = args.image_size if isinstance(args.image_size, list) else (args.image_size, args.image_size)
+ num_frames = args.num_frames
+ latent_size = TemporalAutoencoder.get_latent_size((num_frames, img_h, img_w))
+
+ # 2.2 Llama 3
+ logger.info("Transformer init")
+ model = init_model(**args.model).set_train(False)
+
+ # 2.3 text embeddings
+ prompt_prefix = [os.path.basename(emb)[:-4] for emb in ul2_emb]
+ ul2_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in ul2_emb], dtype=ms.float32)
+ # metaclip_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in metaclip_emb], dtype=ms.float32)
+ metaclip_emb = ms.Tensor(np.ones((ul2_emb.shape[0], 300, 1280)), dtype=ms.float32) # FIXME: replace with actual
+ byt5_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in byt5_emb], dtype=ms.float32)
+ num_prompts = ul2_emb.shape[0]
+
+ # 3. build inference pipeline
+ pipeline = InferPipeline(
+ model,
+ tae,
+ latent_size,
+ guidance_scale=args.guidance_scale,
+ num_sampling_steps=args.num_sampling_steps,
+ sample_method=args.sample_method,
+ micro_batch_size=args.micro_batch_size,
+ )
+
+ # 4. print key info
+ key_info = "Key Settings:\n" + "=" * 50 + "\n"
+ key_info += "\n".join(
+ [
+ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}",
+ f"Num of captions: {num_prompts}",
+ f"Model dtype: {args.model.dtype}",
+ f"TAE dtype: {args.tae.init_args.dtype if tae is not None else 'N/A'}",
+ f"Image size: {(img_h, img_w)}",
+ f"Num frames: {num_frames}",
+ f"Sampling steps {args.num_sampling_steps}",
+ f"CFG guidance scale: {args.guidance_scale}",
+ ]
+ )
+ key_info += "\n" + "=" * 50
+ logger.info(key_info)
+
+ for i in range(0, num_prompts, args.batch_size):
+ end_i = min(i + args.batch_size, num_prompts)
+ logger.info("Sampling captions:")
+ for j in range(i, end_i):
+ logger.info(prompt_prefix[j])
+
+ # infer
+ start_time = time.perf_counter()
+ sample, latent = pipeline(
+ ul2_emb=ul2_emb[i:end_i],
+ metaclip_emb=metaclip_emb[i:end_i],
+ byt5_emb=byt5_emb[i:end_i],
+ num_frames=num_frames,
+ )
+ batch_time = time.perf_counter() - start_time
+ logger.info(
+ f"Batch time cost: {batch_time:.3f}s,"
+ f" sampling speed: {args.num_sampling_steps * (end_i - i) / batch_time:.2f} step/s"
+ )
+
+ if args.enable_sequence_parallel and rank_id > 0:
+ # in sequence parallel mode, results from all ranks are identical, so save results only from rank 0
+ continue
+
+ # save results
+ for j in range(0, end_i - i):
+ fn = prompt_prefix[i + j]
+ save_fp = f"{save_dir}/{fn}.{args.save_format}"
+ latent_save_fp = f"{latent_dir}/{fn}.npy"
+
+ # save videos
+ if sample is not None:
+ save_videos(to_numpy(sample[j]), save_fp, fps=args.fps)
+ logger.info(f"Video saved in {save_fp}")
+ # save decoded latents
+ if args.save_latent:
+ np.save(latent_save_fp, to_numpy(latent[j]))
+ logger.info(f"Denoised latents saved in {latent_save_fp}")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description="Movie Gen inference script.")
+ parser.add_argument(
+ "-c",
+ "--config",
+ action=ActionConfigFile,
+ help="Path to load a config yaml file that describes the setting which will override the default arguments.",
+ )
+ parser.add_function_arguments(init_train_env, "env")
+ parser.add_function_arguments(init_model, "model", skip={"resume"})
+ tae_group = parser.add_argument_group("TAE parameters")
+ tae_group.add_subclass_arguments(TemporalAutoencoder, "tae", instantiate=False, required=False)
+ infer_group = parser.add_argument_group("Inference parameters")
+ infer_group.add_class_arguments(InferPipeline, skip={"model", "tae", "latent_size"}, instantiate=False)
+ infer_group.add_argument("--image_size", type=int, nargs="+", help="Output video size")
+ infer_group.add_argument("--num_frames", type=int, default=16, help="number of frames")
+ infer_group.add_argument("--fps", type=int, default=16, help="FPS in the saved video")
+ infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num", "enable_sp"})
+ infer_group.add_argument("--batch_size", type=int, default=1)
+ infer_group.add_argument("--enable_sequence_parallel", type=bool, default=False, help="enable sequence parallel.")
+ save_group = parser.add_argument_group("Saving options")
+ save_group.add_argument(
+ "--save_format",
+ default="mp4",
+ choices=["gif", "mp4", "png"],
+ type=str,
+ help="video format for saving the sampling output: gif, mp4 or png",
+ )
+ save_group.add_argument(
+ "--output_path",
+ default="output/",
+ type=path_type("dcc"), # path to a directory that can be created if it does not exist
+ help="Output directory to save training results.",
+ )
+ save_group.add_argument(
+ "--append_timestamp",
+ type=bool,
+ default=True,
+ help="If true, a subfolder named with timestamp under output_path will be created to save the sampling results",
+ )
+ save_group.add_argument(
+ "--save_latent",
+ type=bool,
+ default=False,
+ help="Save denoised video latent. If True, the denoised latents will be saved in $output_path/denoised_latents",
+ )
+ cfg = parser.parse_args()
+ main(cfg)
diff --git a/examples/moviegen/scripts/inference_tae.py b/examples/moviegen/scripts/inference_tae.py
new file mode 100644
index 0000000000..854afe3b21
--- /dev/null
+++ b/examples/moviegen/scripts/inference_tae.py
@@ -0,0 +1,190 @@
+import glob
+import logging
+import os
+import sys
+from math import ceil
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+from jsonargparse import ArgumentParser
+from jsonargparse.typing import path_type
+from tqdm import tqdm, trange
+
+from mindspore import Tensor, amp
+from mindspore import dtype as mstype
+from mindspore import get_context, nn
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
+
+from mg.dataset.tae_dataset import VideoDataset
+from mg.models.tae import TemporalAutoencoder
+from mg.utils import to_numpy
+
+from mindone.data import create_dataloader
+from mindone.utils import init_train_env, set_logger
+from mindone.visualize import save_videos
+
+logger = logging.getLogger(__name__)
+
+Path_dr = path_type("dr", docstring="path to a directory that exists and is readable")
+
+
+def encode(args, tae: TemporalAutoencoder, save_dir: Path, rank_id: int, device_num: int, mode: int):
+ dataset = VideoDataset(
+ **args.video_data.init_args,
+ sample_n_frames=10**5, # read the full video, limitation of `albumentations` (i.e., `additional_targets`)
+ output_columns=["video", "rel_path"],
+ )
+ dataloader = create_dataloader(
+ dataset, drop_remainder=False, device_num=device_num, rank_id=rank_id, **args.dataloader
+ )
+
+ # print key info
+ key_info = "Key Settings:\n" + "=" * 50 + "\n"
+ key_info += "\n".join(
+ [
+ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {mode}",
+ f"Debug mode: {args.env.debug}",
+ f"TAE dtype: {args.tae.dtype}",
+ f"Image size: {args.video_data.init_args.size}",
+ f"Crop size: {args.video_data.init_args.crop_size}",
+ f"Num of batches: {dataloader.get_dataset_size()}",
+ ]
+ )
+ key_info += "\n" + "=" * 50
+ logger.info(key_info)
+
+ for samples in tqdm(dataloader.create_tuple_iterator(num_epochs=1), total=dataloader.get_dataset_size()):
+ _, mean, logvar = tae.encode(samples[0])
+ mean, logvar = to_numpy(mean), to_numpy(logvar)
+ std = np.exp(0.5 * np.clip(logvar, -30.0, 20.0))
+
+ for m, s, path in zip(mean, std, samples[1].tolist()):
+ out_path = save_dir / path
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ np.savez(out_path.with_suffix(".npz"), latent_mean=m, latent_std=s)
+ logger.info(f"Completed. Latents saved in {save_dir}")
+
+
+def prepare_latents(folder: Optional[Path_dr] = None, rank_id: int = 0, device_num: int = 1) -> List[str]:
+ latents = sorted(glob.glob(os.path.join(folder, "*.npy")))
+ latents = latents[rank_id::device_num]
+ logger.info(f"Number of latents for rank {rank_id}: {len(latents)}")
+ return latents
+
+
+def decode(args, tae: TemporalAutoencoder, save_dir: Path, rank_id: int, device_num: int, mode: int):
+ latent_paths = prepare_latents(**args.latent_data, rank_id=rank_id, device_num=device_num)
+ batch_size = args.dataloader.batch_size
+
+ # print key info
+ latent_shape = np.load(latent_paths[0]).shape
+ key_info = "Key Settings:\n" + "=" * 50 + "\n"
+ key_info += "\n".join(
+ [
+ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {mode}",
+ f"Debug mode: {args.env.debug}",
+ f"TAE dtype: {args.tae.dtype}",
+ f"Latent shape: {latent_shape}",
+ f"Num of batches: {ceil(len(latent_paths) / batch_size)}",
+ ]
+ )
+ key_info += "\n" + "=" * 50
+ logger.info(key_info)
+
+ for i in trange(0, len(latent_paths), batch_size):
+ lps = latent_paths[i : i + batch_size]
+ latents = np.stack([np.load(lp) for lp in lps])
+
+ latents = np.transpose(latents, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy
+ latents = latents / tae.scale_factor + tae.shift_factor
+ videos = to_numpy(tae.decode(Tensor(latents), target_num_frames=args.num_frames))
+ videos = np.clip((videos + 1.0) / 2.0, 0.0, 1.0)
+ videos = np.transpose(videos, (0, 2, 3, 4, 1))
+
+ for lp, video in zip(lps, videos):
+ save_fp = save_dir / Path(lp).with_suffix("." + args.save_format).name
+ save_videos(video, str(save_fp), fps=args.fps)
+ logger.info(f"Completed. Videos saved in {save_dir}")
+
+
+def main(args):
+ # 1. init env
+ _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical?
+ mode = get_context("mode") # `init_train_env()` may change the mode during debugging
+
+ save_dir = Path(args.output_path.absolute)
+ save_dir.mkdir(parents=True, exist_ok=True)
+ set_logger(name="", output_dir=str(save_dir), rank=rank_id)
+
+ # 3. TAE initiate and weight loading
+ logger.info("Initializing TAE...")
+ tae = TemporalAutoencoder(**args.tae).set_train(False)
+ if tae.dtype != mstype.float32:
+ # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative
+ amp.custom_mixed_precision(
+ tae, black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=tae.dtype
+ )
+ # TODO: add dynamic shape support
+
+ if args.video_data is not None:
+ logger.info("Encoding video data.")
+ encode(args, tae, save_dir, rank_id, device_num, mode)
+ elif args.latent_data.folder is not None:
+ logger.info("Decoding latent data.")
+ decode(args, tae, save_dir, rank_id, device_num, mode)
+ else:
+ raise ValueError("Either `video_data` or `latent_data` must be provided.")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description="TAE inference script.")
+ parser.add_function_arguments(init_train_env, "env")
+ parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False)
+ parser.add_subclass_arguments(
+ VideoDataset,
+ "video_data",
+ skip={"random_crop", "flip", "sample_n_frames", "return_image", "output_columns"},
+ instantiate=False,
+ required=False,
+ )
+ parser.add_function_arguments(prepare_latents, "latent_data", skip={"rank_id", "device_num"})
+ parser.add_function_arguments(
+ create_dataloader,
+ "dataloader",
+ skip={
+ "dataset",
+ "transforms",
+ "batch_transforms",
+ "project_columns",
+ "shuffle",
+ "num_workers", # no transformations inside `.map()`
+ "drop_remainder",
+ "device_num",
+ "rank_id",
+ "enable_modelarts",
+ },
+ )
+ parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse")
+ parser.add_argument("--num_frames", default=256, type=int, help="Number of in the output video.")
+ parser.add_argument("--fps", type=int, default=16, help="FPS in the saved video")
+ parser.add_argument(
+ "--save_format",
+ default="mp4",
+ choices=["gif", "mp4", "png"],
+ type=str,
+ help="video format for saving the sampling output: gif, mp4 or png",
+ )
+ parser.add_argument(
+ "--output_path",
+ default="output/",
+ type=path_type("dcc"), # path to a directory that can be created if it does not exist
+ help="Output directory to save training results.",
+ )
+ cfg = parser.parse_args()
+ main(cfg)
diff --git a/examples/moviegen/scripts/inference_text_enc.py b/examples/moviegen/scripts/inference_text_enc.py
new file mode 100644
index 0000000000..7dee7fcf91
--- /dev/null
+++ b/examples/moviegen/scripts/inference_text_enc.py
@@ -0,0 +1,129 @@
+import logging
+import os
+import sys
+from csv import DictReader
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+from jsonargparse import ArgumentParser
+from jsonargparse.typing import Path_fr, path_type
+from tqdm import trange
+from transformers import AutoTokenizer
+
+import mindspore as ms
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
+
+from mg.utils import MODEL_DTYPE, to_numpy
+
+from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel
+from mindone.utils import init_train_env, set_logger
+
+logger = logging.getLogger(__name__)
+
+Path_dcc = path_type("dcc") # path to a directory that can be created if it does not exist
+
+
+def prepare_captions(
+ prompts_file: Path_fr,
+ output_path: Path_dcc,
+ column_names: Tuple[str, str] = ("video", "caption"),
+ rank_id: int = 0,
+ device_num: int = 1,
+) -> Tuple[List[Path], List[str]]:
+ """
+ Reads prompts from a file and returns a list of saving paths and a list of captions.
+
+ Args:
+ prompts_file: Path to the prompt file. Can be a csv file or a txt file.
+ output_path: Path to the output directory where the embeddings will be saved.
+ column_names: [CSV only] Tuple of column names for video paths and captions.
+ rank_id: Current rank id for distributed inference.
+ device_num: Number of devices used for distributed inference.
+
+ Returns:
+ A tuple containing a list of saving paths and a list of captions.
+ """
+ prompts_file = prompts_file.absolute
+ output_path = Path(output_path.absolute)
+ with open(prompts_file, "r", encoding="utf-8") as file:
+ if prompts_file.endswith(".csv"):
+ paths, captions = zip(
+ *[
+ (output_path / Path(row[column_names[0]]).with_suffix(".npz"), row[column_names[1]])
+ for row in DictReader(file)
+ ]
+ )
+ return paths[rank_id::device_num], captions[rank_id::device_num]
+ else:
+ captions = [line.strip() for line in file] # preserve empty lines
+ paths = [
+ output_path / (f"{i:03d}-" + "-".join(Path(cap).stem.split(" ")[:10]) + ".npz")
+ for i, cap in enumerate(captions)
+ ]
+ return paths[rank_id::device_num], captions[rank_id::device_num]
+
+
+def main(args):
+ save_dir = os.path.abspath(args.output_path)
+ os.makedirs(save_dir, exist_ok=True)
+ set_logger(name="", output_dir=save_dir)
+
+ _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical?
+
+ paths, captions = prepare_captions(args.prompts_file, args.output_path, args.column_names, rank_id, device_num)
+
+ # model initiate and weight loading
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.model_name, local_files_only=True, clean_up_tokenization_spaces=False
+ )
+ model = T5EncoderModel.from_pretrained(
+ args.model_name, mindspore_dtype=MODEL_DTYPE[args.dtype.lower()], local_files_only=True
+ ).set_train(False)
+
+ info = (
+ f"Model name: {args.model_name}\nPrecision: {args.dtype}\nEmbedded sequence length: {args.model_max_length}"
+ f"\nNumber of devices: {device_num}\nRank ID: {rank_id}\nNumber of captions: {len(captions)}"
+ )
+ logger.info(info)
+
+ for i in trange(0, len(captions), args.batch_size):
+ batch = captions[i : i + args.batch_size]
+ inputs = tokenizer(
+ batch,
+ max_length=args.model_max_length,
+ padding="max_length",
+ return_attention_mask=True,
+ truncation=True,
+ return_tensors="np",
+ )
+ tokens = inputs.input_ids
+ masks = inputs.attention_mask
+ output = model(ms.Tensor(inputs.input_ids, dtype=ms.int32), ms.Tensor(inputs.attention_mask, dtype=ms.uint8))[0]
+ output = to_numpy(output).astype(np.float32)
+
+ for j in range(len(output)):
+ paths[i + j].parent.mkdir(parents=True, exist_ok=True)
+ with open(os.path.join(save_dir, paths[i + j]), "wb") as f:
+ np.savez(f, mask=masks[j], text_emb=output[j], tokens=tokens[j])
+
+ logger.info(f"Finished. Embeddings saved to {save_dir}")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description="Text embeddings generation script.")
+ parser.add_function_arguments(init_train_env, "env")
+ parser.add_argument("--model_name", type=str, default="google/byt5-small", help="Text encoder model name.")
+ parser.add_argument(
+ "--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Text encoder model precision."
+ )
+ parser.add_function_arguments(prepare_captions, as_group=False, skip={"rank_id", "device_num"})
+ parser.add_argument("--batch_size", default=10, type=int, help="Inference batch size.")
+ parser.add_argument("--model_max_length", type=int, default=300, help="Model's embedded sequence length.")
+ cfg = parser.parse_args()
+ main(cfg)
diff --git a/examples/moviegen/scripts/moviegen/30B_stage2_train.sh b/examples/moviegen/scripts/moviegen/30B_stage2_train.sh
new file mode 100644
index 0000000000..73e04943e1
--- /dev/null
+++ b/examples/moviegen/scripts/moviegen/30B_stage2_train.sh
@@ -0,0 +1,30 @@
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# plot memory usage, feature/model: 1
+export MS_MEMORY_STATISTIC=0
+
+# operation/graph fusion for dynamic shape
+# export MS_DEV_ENABLE_KERNEL_PACKET=on # TODO: add dynamic shape support
+
+# log level
+export GLOG_v=2
+
+output_dir=output/stage2_t2iv_256px/$(date +"%Y.%m.%d-%H.%M.%S")
+
+msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
+python scripts/train.py \
+ --config configs/train/stage2_t2iv_256px.yaml \
+ --env.mode 0 \
+ --env.jit_level O1 \
+ --env.max_device_memory 59GB \
+ --env.distributed True \
+ --model.name=llama-30B \
+ --train.settings.zero_stage 3 \
+ --train.sequence_parallel.shards 8 \
+ --dataset.csv_path CSV_PATH \
+ --dataset.video_folder VIDEO_FOLDER \
+ --dataset.tae_latent_folder TAE_LATENT_FOLDER \
+ --dataset.text_emb_folder.ul2 UL2_FOLDER \
+ --dataset.text_emb_folder.byt5 BYT5_FOLDER \
+ --dataloader.batch_size 1 \
+ --train.ema "" \
+ --train.output_path "$output_dir"
diff --git a/examples/moviegen/scripts/moviegen/stage1_train.sh b/examples/moviegen/scripts/moviegen/stage1_train.sh
new file mode 100644
index 0000000000..d8c60d8871
--- /dev/null
+++ b/examples/moviegen/scripts/moviegen/stage1_train.sh
@@ -0,0 +1,23 @@
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# plot memory usage, feature/model: 1
+export MS_MEMORY_STATISTIC=0
+
+# log level
+export GLOG_v=2
+
+output_dir=output/stage1_t2i_256px/$(date +"%Y.%m.%d-%H.%M.%S")
+
+msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
+python scripts/train.py \
+ --config configs/train/stage1_t2i_256px.yaml \
+ --env.mode 0 \
+ --env.jit_level O1 \
+ --env.max_device_memory 59GB \
+ --env.distributed True \
+ --train.settings.zero_stage 2 \
+ --dataset.csv_path CSV_PATH \
+ --dataset.video_folder VIDEO_FOLDER \
+ --dataset.text_emb_folder.ul2 UL2_FOLDER \
+ --dataset.text_emb_folder.byt5 BYT5_FOLDER \
+ --train.ema "" \
+ --train.output_path "$output_dir"
diff --git a/examples/moviegen/scripts/moviegen/stage2_train.sh b/examples/moviegen/scripts/moviegen/stage2_train.sh
new file mode 100644
index 0000000000..0a59a72593
--- /dev/null
+++ b/examples/moviegen/scripts/moviegen/stage2_train.sh
@@ -0,0 +1,26 @@
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# plot memory usage, feature/model: 1
+export MS_MEMORY_STATISTIC=0
+
+# operation/graph fusion for dynamic shape
+export MS_DEV_ENABLE_KERNEL_PACKET=on
+
+# log level
+export GLOG_v=2
+
+output_dir=output/stage2_t2iv_256px/$(date +"%Y.%m.%d-%H.%M.%S")
+
+msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
+python scripts/train.py \
+ --config configs/train/stage2_t2iv_256px.yaml \
+ --env.mode 0 \
+ --env.jit_level O1 \
+ --env.max_device_memory 59GB \
+ --env.distributed True \
+ --train.settings.zero_stage 2 \
+ --dataset.csv_path CSV_PATH \
+ --dataset.video_folder VIDEO_FOLDER \
+ --dataset.text_emb_folder.ul2 UL2_FOLDER \
+ --dataset.text_emb_folder.byt5 BYT5_FOLDER \
+ --train.ema "" \
+ --train.output_path "$output_dir"
diff --git a/examples/moviegen/scripts/moviegen/stage3_train.sh b/examples/moviegen/scripts/moviegen/stage3_train.sh
new file mode 100644
index 0000000000..2f8ea12d8b
--- /dev/null
+++ b/examples/moviegen/scripts/moviegen/stage3_train.sh
@@ -0,0 +1,26 @@
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# plot memory usage, feature/model: 1
+export MS_MEMORY_STATISTIC=0
+
+# operation/graph fusion for dynamic shape
+export MS_DEV_ENABLE_KERNEL_PACKET=on
+
+# log level
+export GLOG_v=2
+
+output_dir=output/stage3_t2iv_768px/$(date +"%Y.%m.%d-%H.%M.%S")
+
+msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
+python scripts/train.py \
+ --config configs/train/stage3_t2iv_768px.yaml \
+ --env.mode 0 \
+ --env.jit_level O1 \
+ --env.max_device_memory 59GB \
+ --env.distributed True \
+ --train.settings.zero_stage 2 \
+ --dataset.csv_path CSV_PATH \
+ --dataset.video_folder VIDEO_FOLDER \
+ --dataset.text_emb_folder.ul2 UL2_FOLDER \
+ --dataset.text_emb_folder.byt5 BYT5_FOLDER \
+ --train.ema "" \
+ --train.output_path "$output_dir"
diff --git a/examples/moviegen/scripts/run/run_eval_tae_ucf101.sh b/examples/moviegen/scripts/run/run_eval_tae_ucf101.sh
deleted file mode 100755
index c04ded5f54..0000000000
--- a/examples/moviegen/scripts/run/run_eval_tae_ucf101.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-python scripts/inference_vae.py \
---mode=0 \
---jit_level=O1 \
---num_frames=32 \
---batch_size 1 \
---image_size 256 \
---ckpt_path outputs/train_tae/ckpt/tae-e25.ckpt \
---csv_path datasets/ucf101_test.csv \
---video_folder datasets/UCF-101 \
diff --git a/examples/moviegen/scripts/tae/run_eval_tae_ucf101.sh b/examples/moviegen/scripts/tae/run_eval_tae_ucf101.sh
new file mode 100644
index 0000000000..8d14128121
--- /dev/null
+++ b/examples/moviegen/scripts/tae/run_eval_tae_ucf101.sh
@@ -0,0 +1,9 @@
+python scripts/eval_tae.py \
+--mode=0 \
+--jit_level=O1 \
+--sample_n_frames=32 \
+--batch_size 1 \
+--size 256 \
+--pretrained outputs/train_tae/ckpt/tae-e25.ckpt \
+--csv_path datasets/ucf101_test.csv \
+--folder datasets/UCF-101 \
diff --git a/examples/moviegen/scripts/run/run_train_tae.sh b/examples/moviegen/scripts/tae/run_train_tae.sh
old mode 100755
new mode 100644
similarity index 94%
rename from examples/moviegen/scripts/run/run_train_tae.sh
rename to examples/moviegen/scripts/tae/run_train_tae.sh
index 9cc46d6f3e..b0d01e0f9c
--- a/examples/moviegen/scripts/run/run_train_tae.sh
+++ b/examples/moviegen/scripts/tae/run_train_tae.sh
@@ -16,6 +16,6 @@ python scripts/train_tae.py \
--config configs/tae/train/mixed_256x256x32.yaml \
--use_outlier_penalty_loss False \
--csv_path datasets/ucf101_train.csv \
---video_folder datasets/UCF-101 \
+--folder datasets/UCF-101 \
--output_path=$output_dir \
--epochs=100 --ckpt_save_interval=5 \
diff --git a/examples/moviegen/scripts/train.py b/examples/moviegen/scripts/train.py
new file mode 100644
index 0000000000..01f30dae2f
--- /dev/null
+++ b/examples/moviegen/scripts/train.py
@@ -0,0 +1,351 @@
+import logging
+import os
+import sys
+from typing import Dict, Tuple, Union
+
+from jsonargparse import ActionConfigFile, ArgumentParser
+from jsonargparse.typing import path_type
+
+import mindspore.dataset as ds
+from mindspore import GRAPH_MODE, Model, Symbol, Tensor, amp
+from mindspore import dtype as mstype
+from mindspore import get_context, nn, set_context, set_seed
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
+sys.path.append(mindone_lib_path)
+sys.path.append(os.path.join(__dir__, ".."))
+
+from mg.acceleration import create_parallel_group
+from mg.dataset import ImageVideoDataset, bucket_split_function
+from mg.models.tae import TemporalAutoencoder
+from mg.pipelines import DiffusionWithLoss
+from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper
+from mg.utils import EMA, init_model, resume_train_net
+from mg.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback
+
+from mindone.data import create_dataloader
+from mindone.trainers import create_optimizer, create_scheduler
+from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, StopAtStepCallback
+from mindone.trainers.zero import prepare_train_network
+from mindone.utils import count_params, init_train_env, set_logger
+
+logger = logging.getLogger(__name__)
+
+
+def initialize_dataset(
+ dataset_args, dataloader_args, device_num: int, shard_rank_id: int
+) -> Tuple[Union[ds.BatchDataset, ds.BucketBatchByLengthDataset], int]:
+ dataset = ImageVideoDataset(**dataset_args)
+ transforms = (
+ dataset.train_transforms(dataset_args.target_size) if not dataset_args.apply_transforms_dataset else None
+ )
+
+ logger.info(f"Initializing the dataloader: assigning shard ID {shard_rank_id} out of {device_num} total shards.")
+ dataloader_args = dataloader_args.as_dict()
+ batch_size = dataloader_args.pop("batch_size")
+ dataloader = create_dataloader(
+ dataset,
+ batch_size=batch_size if isinstance(batch_size, int) else 0, # Turn off batching if using buckets
+ transforms=transforms,
+ device_num=device_num,
+ rank_id=shard_rank_id,
+ **dataloader_args,
+ )
+ if isinstance(batch_size, dict): # if buckets are used
+ hash_func, bucket_boundaries, bucket_batch_sizes = bucket_split_function(**batch_size)
+ dataloader = dataloader.bucket_batch_by_length(
+ ["video"],
+ bucket_boundaries,
+ bucket_batch_sizes,
+ element_length_function=hash_func,
+ drop_remainder=dataloader_args["drop_remainder"],
+ )
+ return dataloader, len(dataset)
+
+
+def main(args):
+ # 1. init env
+ args.train.output_path = os.path.abspath(args.train.output_path)
+ os.makedirs(args.train.output_path, exist_ok=True)
+ device_id, rank_id, device_num = init_train_env(**args.env)
+ mode = get_context("mode") # `init_train_env()` may change the mode during debugging
+
+ # if bucketing is used in Graph mode, activate dynamic mode
+ if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict):
+ set_context(graph_kernel_flags="--disable_packet_ops=Reshape")
+
+ # 1.1 init model parallel
+ shard_rank_id = rank_id
+ if (shards := args.train.sequence_parallel.shards) > 1:
+ create_parallel_group(**args.train.sequence_parallel)
+ device_num = device_num // shards
+ shard_rank_id = rank_id // shards
+
+ # FIXME: Improve seed setting
+ set_seed(args.env.seed + shard_rank_id) # set different seeds per NPU for sampling different timesteps
+ ds.set_seed(args.env.seed) # keep MS.dataset's seed consistent as datasets first shuffled and then distributed
+
+ set_logger("", output_dir=args.train.output_path, rank=rank_id)
+
+ # instantiate classes only after initializing training environment
+ initializer = parser.instantiate_classes(cfg)
+
+ # 2. model initialize and weight loading
+ # 2.1 TAE
+ if not args.dataset.tae_latent_folder or (
+ args.valid.dataset and not args.valid.dataset.init_args.tae_latent_folder
+ ):
+ logger.info("Initializing TAE...")
+ tae = TemporalAutoencoder(**args.tae).set_train(False)
+ if tae.dtype != mstype.float32:
+ # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative
+ amp.custom_mixed_precision(
+ tae, black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=tae.dtype
+ )
+ if args.model.in_channels != tae.out_channels:
+ logger.warning(
+ f"The number of model input channels ({args.model.in_channels}) doesn't match the number of TAE output"
+ f" channels ({tae.out_channels}). Setting it to {tae.out_channels}."
+ )
+ args.model.in_channels = tae.out_channels
+ else:
+ logger.info("TAE latent folder provided. Skipping TAE initialization.")
+ tae = None
+
+ # 2.2 Llama 3
+ logger.info("Transformer init")
+ network = init_model(resume=args.train.resume_ckpt is not None, **args.model)
+ # 2.3 LossWrapper
+ rflow_loss_wrapper = RFlowLossWrapper(network)
+
+ # 3. build training network
+ latent_diffusion_with_loss = DiffusionWithLoss(
+ rflow_loss_wrapper, tae, video_emb_cached=bool(args.dataset.tae_latent_folder)
+ )
+
+ # 4. build train & val datasets
+ dataloader, dataset_len = initialize_dataset(args.dataset, args.dataloader, device_num, shard_rank_id)
+
+ eval_diffusion_with_loss, val_dataloader = None, None
+ if args.valid.dataset is not None:
+ val_dataloader, _ = initialize_dataset(
+ args.valid.dataset.init_args, args.valid.dataloader, device_num, shard_rank_id
+ )
+ eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps)
+ eval_diffusion_with_loss = DiffusionWithLoss(
+ eval_rflow_loss, tae, video_emb_cached=bool(args.valid.dataset.init_args.tae_latent_folder)
+ )
+
+ # 5. build training utils: lr, optim, callbacks, trainer
+ # 5.1 LR
+ lr = create_scheduler(steps_per_epoch=0, **args.train.lr_scheduler)
+
+ # 5.2 optimizer
+ optimizer = create_optimizer(latent_diffusion_with_loss.trainable_params(), lr=lr, **args.train.optimizer)
+
+ # 5.3 trainer (standalone and distributed)
+ ema = EMA(latent_diffusion_with_loss.network, **args.train.ema.init_args) if args.train.ema else None
+ loss_scaler = initializer.train.loss_scaler
+ net_with_grads = prepare_train_network(
+ latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings
+ )
+
+ start_epoch, global_step = 0, 0
+ if args.train.resume_ckpt is not None:
+ start_epoch, global_step = resume_train_net(net_with_grads, resume_ckpt=os.path.abspath(args.train.resume_ckpt))
+
+ # TODO: validation graph?
+ # if bucketing is used in Graph mode, activate dynamic inputs
+ if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict):
+ bs = Symbol(unique=True)
+ video = Tensor(shape=[bs, None, args.model.in_channels if tae is None else 3, None, None], dtype=mstype.float32)
+ # FIXME: fix sequence length
+ ul2_emb = Tensor(shape=[bs, 300, 4096], dtype=mstype.float32)
+ byt5_emb = Tensor(shape=[bs, 100, 1472], dtype=mstype.float32)
+ net_with_grads.set_inputs(video, ul2_emb, byt5_emb)
+ logger.info("Dynamic inputs are initialized for bucket config training in Graph mode.")
+
+ model = Model(net_with_grads)
+
+ # 5.4 callbacks
+ callbacks = [OverflowMonitor()]
+ if val_dataloader is not None:
+ callbacks.extend(
+ [
+ ValidationCallback(
+ network=eval_diffusion_with_loss,
+ dataset=val_dataloader,
+ alpha_smooth=0.01, # FIXME
+ valid_frequency=args.valid.frequency,
+ ema=ema,
+ ),
+ ReduceLROnPlateauByStep(optimizer, **args.train.lr_reduce_on_plateau),
+ ]
+ )
+
+ if args.train.settings.zero_stage == 3 or rank_id == 0:
+ ckpt_save_dir = (
+ os.path.join(args.train.output_path, f"rank_{rank_id}/ckpt")
+ if args.train.settings.zero_stage == 3
+ else os.path.join(args.train.output_path, "ckpt")
+ )
+ callbacks.append(
+ EvalSaveCallback(
+ network=latent_diffusion_with_loss.network,
+ model_name=args.model.name,
+ rank_id=0 if args.train.settings.zero_stage == 3 else rank_id, # ZeRO-3 shards across all ranks
+ ckpt_save_dir=ckpt_save_dir,
+ ema=ema,
+ step_mode=True,
+ use_step_unit=True,
+ start_epoch=start_epoch,
+ resume_prefix_blacklist=("tae.", "swap."),
+ train_steps=args.train.steps,
+ **args.train.save,
+ )
+ )
+
+ if rank_id == 0:
+ callbacks.append(
+ PerfRecorderCallback(
+ args.train.output_path, file_name="result_val.log", metric_names=["eval_loss", "eval_loss_smoothed"]
+ )
+ )
+
+ callbacks.append(StopAtStepCallback(train_steps=args.train.steps, global_step=global_step))
+
+ # 5.5 print out key info and save config
+ if rank_id == 0:
+ num_params_tae, num_params_trainable_tae = count_params(tae) if tae is not None else (0, 0)
+ num_params_network, num_params_trainable_network = count_params(network)
+ num_params = num_params_tae + num_params_network
+ num_params_trainable = num_params_trainable_tae + num_params_trainable_network
+ key_info = "Key Settings:\n" + "=" * 50 + "\n"
+ key_info += "\n".join(
+ [
+ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {mode}",
+ f"Debug mode: {args.env.debug}",
+ f"JIT level: {args.env.jit_level}",
+ f"Distributed mode: {args.env.distributed}",
+ f"Data path: {args.dataset.csv_path}",
+ f"Number of samples: {dataset_len}",
+ f"Model name: {args.model.name}",
+ f"Model dtype: {args.model.dtype}",
+ f"TAE dtype: {args.tae.dtype}",
+ f"Num params: {num_params:,} (network: {num_params_network:,}, tae: {num_params_tae:,})",
+ f"Num trainable params: {num_params_trainable:,}",
+ f"Learning rate: {args.train.lr_scheduler.lr:.0e}",
+ f"Batch size: {args.dataloader.batch_size}",
+ f"Image size: {args.dataset.target_size}",
+ f"Frames: {args.dataset.sample_n_frames}",
+ f"Weight decay: {args.train.optimizer.weight_decay}",
+ f"Grad accumulation steps: {args.train.settings.gradient_accumulation_steps}",
+ f"Number of training steps: {args.train.steps}",
+ f"Loss scaler: {args.train.loss_scaler.class_path}",
+ f"Init loss scale: {args.train.loss_scaler.init_args.loss_scale_value}",
+ f"Grad clipping: {args.train.settings.clip_grad}",
+ f"Max grad norm: {args.train.settings.clip_norm}",
+ f"EMA: {ema is not None}",
+ f"Enable flash attention: {args.model.enable_flash_attention}",
+ ]
+ )
+ key_info += "\n" + "=" * 50
+ print(key_info)
+ parser.save(args, args.train.output_path + "/config.yaml", format="yaml", overwrite=True)
+
+ # 6. train
+ logger.info("Start training...")
+ # train() uses epochs, so the training will be terminated by the StopAtStepCallback
+ model.train(args.train.steps, dataloader, callbacks=callbacks, initial_epoch=start_epoch)
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description="Movie Gen training script.")
+ parser.add_argument(
+ "-c",
+ "--config",
+ action=ActionConfigFile,
+ help="Path to load a config yaml file that describes the setting which will override the default arguments.",
+ )
+ parser.add_function_arguments(init_train_env, "env")
+ parser.add_function_arguments(init_model, "model", skip={"resume"})
+ parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False)
+ parser.add_class_arguments(
+ ImageVideoDataset, "dataset", skip={"frames_mask_generator", "t_compress_func"}, instantiate=False
+ )
+ parser.add_function_arguments(
+ create_dataloader,
+ "dataloader",
+ skip={"dataset", "batch_size", "transforms", "batch_transforms", "device_num", "rank_id"},
+ )
+ parser.add_argument( # FIXME: support bucketing
+ "--dataloader.batch_size", default=1, type=Union[int, Dict[str, int]], help="Number of samples per batch"
+ )
+ parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse")
+ parser.add_function_arguments(create_parallel_group, "train.sequence_parallel")
+ parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"})
+ parser.add_class_arguments(
+ ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False
+ )
+ parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"})
+ parser.add_subclass_arguments(
+ nn.Cell,
+ "train.loss_scaler",
+ fail_untyped=False, # no typing in mindspore
+ help="mindspore.nn.FixedLossScaleUpdateCell or mindspore.nn.DynamicLossScaleUpdateCell",
+ )
+ parser.add_function_arguments(
+ prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema"}
+ )
+ parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False)
+ parser.add_function_arguments(resume_train_net, "train", skip={"train_net"})
+ parser.add_argument(
+ "--train.output_path",
+ default="output/",
+ type=path_type("dcc"), # path to a directory that can be created if it does not exist
+ help="Output directory to save training results.",
+ )
+ parser.add_argument("--train.steps", default=100, type=int, help="Number of steps to train. Default: 100.")
+ parser.link_arguments("train.steps", "train.lr_scheduler.total_steps", apply_on="parse")
+ parser.add_class_arguments(
+ EvalSaveCallback,
+ "train.save",
+ skip={
+ "network",
+ "rank_id",
+ "shard_rank_id",
+ "ckpt_save_dir",
+ "output_dir",
+ "ema",
+ "start_epoch",
+ "model_name",
+ "step_mode",
+ "use_step_unit",
+ "train_steps",
+ "resume_prefix_blacklist",
+ },
+ instantiate=False,
+ )
+
+ # validation
+ val_group = parser.add_argument_group("Validation")
+ val_group.add_argument(
+ "valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation."
+ )
+ val_group.add_argument("valid.frequency", type=int, default=1, help="Frequency of validation in steps.")
+ val_group.add_subclass_arguments(
+ ImageVideoDataset,
+ "valid.dataset",
+ skip={"frames_mask_generator", "t_compress_func"},
+ instantiate=False,
+ required=False,
+ )
+ val_group.add_function_arguments(
+ create_dataloader, "valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"}
+ )
+ parser.link_arguments("env.debug", "valid.dataloader.debug", apply_on="parse")
+
+ cfg = parser.parse_args()
+ main(cfg)
diff --git a/examples/moviegen/scripts/train_tae.py b/examples/moviegen/scripts/train_tae.py
index 529c95a617..00a03f113b 100644
--- a/examples/moviegen/scripts/train_tae.py
+++ b/examples/moviegen/scripts/train_tae.py
@@ -3,13 +3,11 @@
import shutil
import sys
import time
-from typing import Tuple
import yaml
import mindspore as ms
-from mindspore import Model, nn
-from mindspore.communication.management import get_group_size, get_rank, init
+from mindspore import Model, amp, nn
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import TimeMonitor
@@ -19,21 +17,21 @@
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
from args_train_tae import parse_args
-from mg.datasets.tae_dataset import create_dataloader
+from mg.dataset.tae_dataset import BatchTransform, VideoDataset
+from mg.models.tae import TemporalAutoencoder
from mg.models.tae.losses import GeneratorWithLoss
from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample
-from mg.models.tae.tae import TemporalAutoencoder
+from mindone.data import create_dataloader
from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback
from mindone.trainers.checkpoint import CheckpointManager, resume_train_network
from mindone.trainers.ema import EMA
from mindone.trainers.lr_schedule import create_scheduler
from mindone.trainers.optim import create_optimizer
from mindone.trainers.train_step import TrainOneStepWrapper
-from mindone.utils.amp import auto_mixed_precision
+from mindone.utils import init_train_env
from mindone.utils.logger import set_logger
from mindone.utils.params import count_params
-from mindone.utils.seed import set_random_seed
os.environ["HCCL_CONNECT_TIMEOUT"] = "6000"
os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "INFNAN_MODE"
@@ -54,99 +52,15 @@ def create_loss_scaler(loss_scaler_type, init_loss_scale, loss_scale_factor=2, s
return loss_scaler
-def init_env(
- mode: int = ms.GRAPH_MODE,
- seed: int = 42,
- distributed: bool = False,
- max_device_memory: str = None,
- device_target: str = "Ascend",
- parallel_mode: str = "data",
- jit_level: str = "O2",
- global_bf16: bool = False,
- debug: bool = False,
-) -> Tuple[int, int]:
- """
- Initialize MindSpore environment.
-
- Args:
- mode: MindSpore execution mode. Default is 0 (ms.GRAPH_MODE).
- seed: The seed value for reproducibility. Default is 42.
- distributed: Whether to enable distributed training. Default is False.
- Returns:
- A tuple containing the device ID, rank ID and number of devices.
- """
- set_random_seed(seed)
-
- if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging
- logger.warning("Debug mode is on, switching execution mode to PyNative.")
- mode = ms.PYNATIVE_MODE
-
- if max_device_memory is not None:
- ms.set_context(max_device_memory=max_device_memory)
-
- # ms.set_context(mempool_block_size="55GB")
- # ms.set_context(pynative_synchronize=True)
- if distributed:
- ms.set_context(
- mode=mode,
- device_target=device_target,
- )
- if parallel_mode == "optim":
- print("use optim parallel")
- ms.set_auto_parallel_context(
- parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL,
- enable_parallel_optimizer=True,
- )
- init()
- device_num = get_group_size()
- rank_id = get_rank()
- else:
- init()
- device_num = get_group_size()
- rank_id = get_rank()
- logger.debug(f"rank_id: {rank_id}, device_num: {device_num}")
- ms.reset_auto_parallel_context()
-
- ms.set_auto_parallel_context(
- parallel_mode=ms.ParallelMode.DATA_PARALLEL,
- gradients_mean=True,
- device_num=device_num,
- )
-
- var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"]
- var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)]
- logger.info(dict(zip(var_info, var_value)))
-
- else:
- device_num = 1
- rank_id = 0
- ms.set_context(
- mode=mode,
- device_target=device_target,
- pynative_synchronize=debug,
- )
-
- if mode == 0:
- ms.set_context(jit_config={"jit_level": jit_level})
-
- if global_bf16:
- # only effective in GE mode, i.e. jit_level: O2
- ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"})
-
- return rank_id, device_num
-
-
def main(args):
# 1. init
- rank_id, device_num = init_env(
- args.mode,
+ _, rank_id, device_num = init_train_env(
+ mode=args.mode,
seed=args.seed,
- distributed=args.use_parallel,
+ distributed=args.distributed,
device_target=args.device_target,
max_device_memory=args.max_device_memory,
- parallel_mode=args.parallel_mode,
jit_level=args.jit_level,
- global_bf16=args.global_bf16,
debug=args.debug,
)
set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level))
@@ -159,23 +73,25 @@ def main(args):
assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported"
image_size = args.image_size[0]
- ds_config = dict(
+ dataset = VideoDataset(
csv_path=args.csv_path,
- data_folder=args.video_folder,
- size=image_size,
+ folder=args.folder,
+ size=args.image_size,
crop_size=args.crop_size,
- sample_n_frames=args.num_frames,
- sample_stride=args.frame_stride,
+ sample_n_frames=args.sample_n_frames,
+ sample_stride=args.sample_stride,
video_column=args.video_column,
random_crop=args.random_crop,
flip=args.flip,
+ output_columns=["video"],
)
+ transform = BatchTransform(mixed_strategy=args.mixed_strategy, mixed_image_ratio=args.mixed_image_ratio)
+ transform = {"operations": transform, "input_columns": ["video"]}
dataloader = create_dataloader(
- ds_config,
- args.batch_size,
- mixed_strategy=args.mixed_strategy,
- mixed_image_ratio=args.mixed_image_ratio,
- num_parallel_workers=args.num_parallel_workers,
+ dataset=dataset,
+ batch_size=args.batch_size,
+ batch_transforms=transform,
+ num_workers=args.num_workers,
max_rowsize=256,
shuffle=True,
device_num=device_num,
@@ -187,26 +103,24 @@ def main(args):
# 3. build models
ae = TemporalAutoencoder(
- pretrained=args.pretrained_model_path,
+ pretrained=args.pretrained,
use_recompute=args.use_recompute,
)
- if args.use_discriminator:
- logging.error("Discriminator is not used or supported in OpenSora v1.2")
-
# mixed precision
# TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level.
- if args.dtype in ["fp16", "bf16"]:
+ if args.dtype != "fp32":
dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype]
# TODO: check ResizeNearest bf16 support for ms>2.3.1
- ae = auto_mixed_precision(
+ ae = amp.custom_mixed_precision(
ae,
- args.amp_level,
- dtype,
- custom_fp32_cells=[SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample]
- if args.vae_keep_updown_fp32
- else [] + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []),
- # custom_fp32_cells=[nn.GroupNorm, SpatialUpsample] if args.vae_keep_gn_fp32 else [SpatialUpsample],
+ black_list=amp.get_black_list()
+ + (
+ [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample]
+ if args.vae_keep_updown_fp32
+ else [] + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else [])
+ ),
+ dtype=dtype,
)
# 4. build net with loss
diff --git a/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh b/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh
new file mode 100755
index 0000000000..e86f871cd8
--- /dev/null
+++ b/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh
@@ -0,0 +1,18 @@
+#!/bin/sh
+set -e
+
+SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))"
+EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")"
+PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")"
+
+export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}"
+
+echo "******** Graph Mode ********"
+msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir="./log_test_sp_graph" --join True ${SCRIPT_DIR}/test_llama_sequence_parallel.py --mode 0
+echo "Done. Check the log at './log_test_sp_graph'."
+echo "========================================================================="
+
+echo "******** Pynative Mode ********"
+msrun --master_port=1235 --worker_num=2 --local_worker_num=2 --log_dir="./log_test_sp_pynative" --join True ${SCRIPT_DIR}/test_llama_sequence_parallel.py --mode 1
+echo "Done. Check the log at './log_test_sp_pynative'."
diff --git a/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py b/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py
new file mode 100644
index 0000000000..1cb753a73f
--- /dev/null
+++ b/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py
@@ -0,0 +1,110 @@
+import argparse
+from typing import Tuple
+
+import numpy as np
+from mg.acceleration import create_parallel_group, get_sequence_parallel_group
+from mg.models.llama.network import LlamaModel
+
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore.communication import get_group_size, init
+
+
+class MeanNet(nn.Cell):
+ def __init__(self, net: nn.Cell) -> None:
+ super().__init__()
+ self.net = net
+
+ def construct(self, *inputs):
+ output = self.net(*inputs)
+ return output.mean() * 1024.0
+
+
+def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]:
+ latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype)
+ timestep = ms.Tensor([35], dtype=ms.int64)
+ ul2_emb = ops.rand([1, 64, 4096], dtype=dtype)
+ metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype)
+ byt5_emb = ops.rand([1, 64, 1472], dtype=dtype)
+ return latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb
+
+
+def get_network_config():
+ config = dict(num_hidden_layers=1, attn_implementation="eager", post_init_weight=False)
+ return config
+
+
+def run_network(mode: int = 0, dtype: ms.Type = ms.float32):
+ ms.set_context(mode=mode)
+ init()
+
+ # prepare data
+ ms.set_seed(1024)
+ data = get_sample_data(dtype=dtype)
+
+ run_parallel_network(data, dtype=dtype)
+
+
+def run_parallel_network(data: Tuple[Tensor, ...], dtype: ms.Type = ms.float32):
+ # non parallel network
+ ms.set_seed(1024)
+ non_parallel_network_cfg = get_network_config()
+ non_parallel_network = LlamaModel(**non_parallel_network_cfg, dtype=dtype)
+
+ # parallel netowrk
+ ms.set_seed(1024)
+ create_parallel_group(shards=get_group_size())
+ parallel_network_cfg = get_network_config()
+ parallel_network = LlamaModel(**parallel_network_cfg, dtype=dtype)
+
+ # load weight
+ for (_, w0), (_, w1) in zip(non_parallel_network.parameters_and_names(), parallel_network.parameters_and_names()):
+ w1.set_data(w0) # FIXME: seed does not work
+ np.testing.assert_allclose(w0.value().asnumpy(), w1.value().asnumpy())
+
+ # test forward
+ non_parallel_out = non_parallel_network(*data).asnumpy()
+ parallel_out = parallel_network(*data).asnumpy()
+
+ assert np.count_nonzero(non_parallel_out) > 0
+ np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape)
+ np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5)
+ print("Test 1 (Forward): Passed.", flush=True)
+
+ # test backward
+ non_parallel_mean_net = MeanNet(non_parallel_network)
+ parallel_mean_net = MeanNet(parallel_network)
+
+ # check the parameter gradient
+ grad_fn = ms.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params())
+ non_parallel_grads = grad_fn(*data)
+
+ grad_fn = ms.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params())
+ parallel_grads = grad_fn(*data)
+
+ # take mean around different ranks
+ sp_group = get_sequence_parallel_group()
+ reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=sp_group)
+ num = get_group_size()
+ syn_parallel_grads = list()
+ for x in parallel_grads:
+ syn_parallel_grads.append(reduce(x) / num)
+
+ pass_grads = []
+ for grad_0, grad_1 in zip(non_parallel_grads, syn_parallel_grads):
+ is_passed = np.allclose(grad_0.asnumpy(), grad_1.asnumpy(), rtol=1.3e-6, atol=1e-5)
+ pass_grads.append(is_passed)
+ assert all(pass_grads), f"Pass rate ({sum(pass_grads)/len(pass_grads) * 100:.3f} %) is not 100 %"
+
+ print("Test 2 (Backward: Parameter Gradient): Passed.", flush=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)"
+ )
+ args = parser.parse_args()
+ run_network(mode=args.mode)
diff --git a/examples/moviegen/tests/ut/test_byt5_pynative.py b/examples/moviegen/tests/ut/test_byt5_pynative.py
new file mode 100644
index 0000000000..79cca8d623
--- /dev/null
+++ b/examples/moviegen/tests/ut/test_byt5_pynative.py
@@ -0,0 +1,85 @@
+import os
+import sys
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer
+from transformers import T5EncoderModel as T5EncoderModel_PyTorch
+
+import mindspore as ms
+
+# FIXME: remove in future when mindone is ready for install
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../../.."))
+from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel, T5LayerNorm
+
+ms.set_context(mode=ms.PYNATIVE_MODE)
+
+fp32_tolerance = 1e-4
+fp16_tolerance = 2e-2
+bf16_tolerance = 2e-1
+
+test_samples = [
+ "Life is like a box of chocolates.",
+ "La vie est comme une boรฎte de chocolat.",
+ "Today is Monday.",
+ "Aujourd'hui c'est lundi.",
+]
+
+tokenizer = AutoTokenizer.from_pretrained("google/byt5-small", local_files_only=True)
+test_samples = tokenizer(test_samples, padding="longest", return_tensors="np")
+
+
+@pytest.fixture(scope="function")
+def byt5_pt():
+ return T5EncoderModel_PyTorch.from_pretrained("google/byt5-small", local_files_only=True)
+
+
+@pytest.fixture(scope="function")
+def byt5_ms():
+ return T5EncoderModel.from_pretrained("google/byt5-small", local_files_only=True)
+
+
+def test_fp32(byt5_ms, byt5_pt):
+ # set models precision
+ byt5_pt.to(torch.float32)
+
+ ms_enc = byt5_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].asnumpy().astype(np.float32)
+ pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().numpy().astype(np.float32)
+ assert np.allclose(ms_enc, pt_enc, atol=fp32_tolerance, rtol=0)
+
+
+def test_fp16(byt5_ms, byt5_pt):
+ # set models precision
+ byt5_ms = ms.amp.custom_mixed_precision(
+ byt5_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.float16
+ )
+ byt5_pt.to(torch.float16)
+
+ ms_enc = byt5_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].asnumpy().astype(np.float32)
+ pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().numpy().astype(np.float32)
+ assert np.allclose(ms_enc, pt_enc, atol=fp16_tolerance, rtol=0)
+
+
+def test_bf16(byt5_ms, byt5_pt):
+ # set models precision
+ byt5_ms = ms.amp.custom_mixed_precision(
+ byt5_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.bfloat16
+ )
+ byt5_pt.to(torch.bfloat16)
+
+ ms_enc = byt5_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].astype(ms.float32).asnumpy()
+ pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().to(torch.float32).numpy()
+ assert np.allclose(ms_enc, pt_enc, atol=bf16_tolerance, rtol=0)
diff --git a/examples/moviegen/tests/ut/test_llama3_forward.py b/examples/moviegen/tests/ut/test_llama3_forward.py
new file mode 100644
index 0000000000..260e65cb04
--- /dev/null
+++ b/examples/moviegen/tests/ut/test_llama3_forward.py
@@ -0,0 +1,16 @@
+import numpy as np
+from mg import llama3_1B
+
+import mindspore as ms
+
+
+def test_llama3_forward_graph():
+ ms.set_context(mode=ms.GRAPH_MODE)
+ network = llama3_1B(attn_implementation="flash_attention", dtype=ms.bfloat16)
+
+ latent_embedding = ms.Tensor(np.ones((1, 16, 8, 24, 44)), dtype=ms.bfloat16)
+ timestep = ms.Tensor([35], dtype=ms.int64)
+ text_embedding = ms.Tensor(np.ones((1, 64, 4096)), dtype=ms.bfloat16)
+ outputs = network(latent_embedding, timestep, text_embedding)
+
+ assert outputs.shape == (1, 16, 8, 24, 44)
diff --git a/examples/moviegen/tests/ut/test_rflow.py b/examples/moviegen/tests/ut/test_rflow.py
new file mode 100644
index 0000000000..7bd3c1c8c0
--- /dev/null
+++ b/examples/moviegen/tests/ut/test_rflow.py
@@ -0,0 +1,27 @@
+import numpy as np
+from mg.schedulers import RFlowLossWrapper
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+
+
+class SimpleBF16Net(nn.Cell):
+ def construct(self, x: Tensor, timestamp: Tensor, text_embedding: Tensor):
+ return x.to(ms.bfloat16)
+
+ @property
+ def dtype(self):
+ return ms.bfloat16
+
+
+def test_rflow_loss():
+ ms.set_context(mode=ms.GRAPH_MODE)
+ network = RFlowLossWrapper(
+ SimpleBF16Net(), num_timesteps=1000, sample_method="logit-normal", loc=0.0, scale=1.0, eps=1e-5
+ )
+
+ latent_embedding = ms.Tensor(np.ones((2, 16, 8, 24, 44)), dtype=ms.bfloat16)
+ text_embedding = ms.Tensor(np.ones((2, 64, 4096)), dtype=ms.bfloat16)
+ loss = network(latent_embedding, text_embedding).item()
+ assert loss > 0
diff --git a/examples/moviegen/tests/test_tae.py b/examples/moviegen/tests/ut/test_tae.py
similarity index 96%
rename from examples/moviegen/tests/test_tae.py
rename to examples/moviegen/tests/ut/test_tae.py
index a57019e0b5..8ec7646d1d 100644
--- a/examples/moviegen/tests/test_tae.py
+++ b/examples/moviegen/tests/ut/test_tae.py
@@ -1,9 +1,15 @@
+import os
import sys
import numpy as np
from PIL import Image
-sys.path.insert(0, ".")
+import mindspore as ms
+
+# TODO: remove in future when mindone is ready for install
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../../"))
+sys.path.append(mindone_lib_path)
from mg.models.tae.modules import (
Conv2_5d,
@@ -20,8 +26,6 @@
from mg.models.tae.sd3_vae import SD3d5_VAE
from mg.models.tae.tae import SDXL_CONFIG, TAE_CONFIG, TemporalAutoencoder
-import mindspore as ms
-
def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", W=128, H=128):
target_size = (H, W)
diff --git a/examples/moviegen/tests/ut/test_transforms.py b/examples/moviegen/tests/ut/test_transforms.py
new file mode 100644
index 0000000000..d988d97bfd
--- /dev/null
+++ b/examples/moviegen/tests/ut/test_transforms.py
@@ -0,0 +1,16 @@
+import numpy as np
+from mg.dataset.transforms import ResizeCrop
+
+
+def test_horizontal_image_crop():
+ image = np.random.randint(0, 256, (150, 250, 3), dtype=np.uint8)
+ rc = ResizeCrop((100, 200))
+ image = rc(image)
+ assert image.shape == (100, 200, 3)
+
+
+def test_vertical_image_crop():
+ image = np.random.randint(0, 256, (250, 150, 3), dtype=np.uint8)
+ rc = ResizeCrop((100, 200))
+ image = rc(image)
+ assert image.shape == (200, 100, 3)
diff --git a/examples/moviegen/tests/ut/test_ul2_pynative.py b/examples/moviegen/tests/ut/test_ul2_pynative.py
new file mode 100644
index 0000000000..6e03e87942
--- /dev/null
+++ b/examples/moviegen/tests/ut/test_ul2_pynative.py
@@ -0,0 +1,83 @@
+import os
+import sys
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer
+from transformers import T5EncoderModel as T5EncoderModel_PyTorch
+
+import mindspore as ms
+
+# FIXME: remove in future when mindone is ready for install
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../../.."))
+from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel, T5LayerNorm
+
+ms.set_context(mode=ms.PYNATIVE_MODE)
+
+fp32_tolerance = 1e-4
+fp16_tolerance = 2e-2
+bf16_tolerance = 2e-1
+
+test_samples = [
+ "Life is like a box of chocolates.",
+ "La vie est comme une boรฎte de chocolat.",
+ "Today is Monday.",
+ "Aujourd'hui c'est lundi.",
+]
+
+tokenizer = AutoTokenizer.from_pretrained("google/ul2", local_files_only=True)
+test_samples = tokenizer(test_samples, padding="max_length", return_tensors="np")
+
+
+@pytest.fixture(scope="function")
+def ul2_pt():
+ return T5EncoderModel_PyTorch.from_pretrained("google/ul2", local_files_only=True)
+
+
+@pytest.fixture(scope="function")
+def ul2_ms():
+ return T5EncoderModel.from_pretrained("google/ul2", local_files_only=True)
+
+
+def test_fp32(ul2_ms, ul2_pt):
+ # set models precision
+ ul2_pt.to(torch.float32)
+
+ ms_enc = ul2_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].asnumpy().astype(np.float32)
+ pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().numpy().astype(np.float32)
+ assert np.allclose(ms_enc, pt_enc, atol=fp32_tolerance, rtol=0)
+
+
+def test_fp16(ul2_ms, ul2_pt):
+ # set models precision
+ ul2_ms = ms.amp.custom_mixed_precision(ul2_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.float16)
+ ul2_pt.to(torch.float16)
+
+ ms_enc = ul2_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].asnumpy().astype(np.float32)
+ pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().numpy().astype(np.float32)
+ assert np.allclose(ms_enc, pt_enc, atol=fp16_tolerance, rtol=0)
+
+
+def test_bf16(ul2_ms, ul2_pt):
+ # set models precision
+ ul2_ms = ms.amp.custom_mixed_precision(
+ ul2_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.bfloat16
+ )
+ ul2_pt.to(torch.bfloat16)
+
+ ms_enc = ul2_ms(
+ ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8)
+ )
+ ms_enc = ms_enc[0].astype(ms.float32).asnumpy()
+ pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False)
+ pt_enc = pt_enc[0].detach().to(torch.float32).numpy()
+ assert np.allclose(ms_enc, pt_enc, atol=bf16_tolerance, rtol=0)
diff --git a/examples/moviegen/tools/download_convert_st.py b/examples/moviegen/tools/download_convert_st.py
new file mode 100644
index 0000000000..6e32186e1e
--- /dev/null
+++ b/examples/moviegen/tools/download_convert_st.py
@@ -0,0 +1,340 @@
+"""
+Modified from
+https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py
+"""
+import argparse
+import json
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional, Set
+
+import requests
+import torch
+from huggingface_hub import HfApi, configure_http_backend, hf_hub_download, snapshot_download
+from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
+
+
+def backend_factory() -> requests.Session:
+ session = requests.Session()
+ session.verify = False
+ return session
+
+
+def _remove_duplicate_names(
+ state_dict: Dict[str, torch.Tensor],
+ *,
+ preferred_names: List[str] = None,
+ discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+ if preferred_names is None:
+ preferred_names = []
+ preferred_names = set(preferred_names)
+ if discard_names is None:
+ discard_names = []
+ discard_names = set(discard_names)
+
+ shareds = _find_shared_tensors(state_dict)
+ to_remove = defaultdict(list)
+ for shared in shareds:
+ complete_names = set([name for name in shared if _is_complete(state_dict[name])])
+ if not complete_names:
+ if len(shared) == 1:
+ # Force contiguous
+ name = list(shared)[0]
+ state_dict[name] = state_dict[name].clone()
+ complete_names = {name}
+ else:
+ raise RuntimeError(
+ "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
+ f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model"
+ " since you could be storing much more memory than needed."
+ " Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+ )
+
+ keep_name = sorted(list(complete_names))[0]
+
+ # Mechanism to preferentially select keys to keep
+ # coming from the on-disk file to allow
+ # loading models saved with a different choice
+ # of keep_name
+ preferred = complete_names.difference(discard_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+
+ if preferred_names:
+ preferred = preferred_names.intersection(complete_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+ for name in sorted(shared):
+ if name != keep_name:
+ to_remove[keep_name].append(name)
+ return to_remove
+
+
+def get_discard_names(
+ model_id: str, revision: Optional[str], folder: str, token: Optional[str], endpoint: str
+) -> List[str]:
+ try:
+ import json
+
+ import transformers
+
+ config_filename = hf_hub_download(
+ model_id, revision=revision, filename="config.json", token=token, cache_dir=folder, endpoint=endpoint
+ )
+ with open(config_filename, "r") as f:
+ config = json.load(f)
+ architecture = config["architectures"][0]
+
+ class_ = getattr(transformers, architecture)
+
+ # Name for this variable depends on transformers version.
+ discard_names = getattr(class_, "_tied_weights_keys", [])
+
+ except Exception:
+ discard_names = []
+ return discard_names
+
+
+class AlreadyExists(Exception):
+ pass
+
+
+def check_file_size(sf_filename: str, pt_filename: str):
+ sf_size = os.stat(sf_filename).st_size
+ pt_size = os.stat(pt_filename).st_size
+
+ if (sf_size - pt_size) / pt_size > 0.01:
+ raise RuntimeError(
+ f"""The file size different is more than 1%:
+ - {sf_filename}: {sf_size}
+ - {pt_filename}: {pt_size}
+ """
+ )
+
+
+def rename(pt_filename: str) -> str:
+ filename, ext = os.path.splitext(pt_filename)
+ local = f"{filename}.safetensors"
+ local = local.replace("pytorch_model", "model")
+ return local
+
+
+def convert_multi(
+ model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str], endpoint: str
+) -> str:
+ filename = hf_hub_download(
+ repo_id=model_id,
+ revision=revision,
+ filename="pytorch_model.bin.index.json",
+ token=token,
+ cache_dir=folder,
+ endpoint=endpoint,
+ )
+ save_path = os.path.dirname(filename)
+ with open(filename, "r") as f:
+ data = json.load(f)
+
+ filenames = set(data["weight_map"].values())
+ for filename in filenames:
+ pt_filename = hf_hub_download(
+ model_id, revision=revision, filename=filename, token=token, cache_dir=folder, endpoint=endpoint
+ )
+ sf_filename = rename(pt_filename)
+ sf_filename = os.path.join(save_path, sf_filename)
+ convert_file(pt_filename, sf_filename, discard_names=discard_names)
+
+ index = os.path.join(save_path, "model.safetensors.index.json")
+ with open(index, "w") as f:
+ newdata = {k: v for k, v in data.items()}
+ newmap = {k: rename(v) for k, v in data["weight_map"].items()}
+ newdata["weight_map"] = newmap
+ json.dump(newdata, f, indent=4)
+
+ return save_path
+
+
+def convert_single(
+ model_id: str,
+ *,
+ revision: Optional[str],
+ folder: str,
+ token: Optional[str],
+ discard_names: List[str],
+ endpoint: str,
+) -> str:
+ pt_filename = hf_hub_download(
+ repo_id=model_id,
+ revision=revision,
+ filename="pytorch_model.bin",
+ token=token,
+ cache_dir=folder,
+ endpoint=endpoint,
+ )
+ save_path = os.path.dirname(pt_filename)
+ sf_name = "model.safetensors"
+ sf_filename = os.path.join(save_path, sf_name)
+ convert_file(pt_filename, sf_filename, discard_names)
+ return save_path
+
+
+def convert_file(
+ pt_filename: str,
+ sf_filename: str,
+ discard_names: List[str],
+):
+ loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
+ if "state_dict" in loaded:
+ loaded = loaded["state_dict"]
+ to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+ metadata = {"format": "pt"}
+ for kept_name, to_remove_group in to_removes.items():
+ for to_remove in to_remove_group:
+ if to_remove not in metadata:
+ metadata[to_remove] = kept_name
+ del loaded[to_remove]
+ # Force tensors to be contiguous
+ loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+ dirname = os.path.dirname(sf_filename)
+ os.makedirs(dirname, exist_ok=True)
+ save_file(loaded, sf_filename, metadata=metadata)
+ check_file_size(sf_filename, pt_filename)
+ reloaded = load_file(sf_filename)
+ for k in loaded:
+ pt_tensor = loaded[k]
+ sf_tensor = reloaded[k]
+ if not torch.equal(pt_tensor, sf_tensor):
+ raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_generic(
+ model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str], endpoint: str
+) -> str:
+ save_path = ""
+ extensions = {".bin", ".ckpt"}
+ for filename in filenames:
+ prefix, ext = os.path.splitext(filename)
+ if ext in extensions:
+ pt_filename = hf_hub_download(
+ model_id, revision=revision, filename=filename, token=token, cache_dir=folder, endpoint=endpoint
+ )
+ save_path = os.path.dirname(pt_filename)
+
+ dirname, raw_filename = os.path.split(filename)
+ if raw_filename == "pytorch_model.bin":
+ # XXX: This is a special case to handle `transformers` and the
+ # `transformers` part of the model which is actually loaded by `transformers`.
+ sf_in_repo = os.path.join(dirname, "model.safetensors")
+ else:
+ sf_in_repo = f"{prefix}.safetensors"
+ sf_filename = os.path.join(save_path, sf_in_repo)
+ convert_file(pt_filename, sf_filename, discard_names=[])
+ return save_path
+
+
+def convert(
+ model_id: str,
+ revision: Optional[str] = None,
+ folder: str = None,
+ force: bool = False,
+ endpoint: str = "https://hf-mirror.com",
+) -> str:
+ api = HfApi(endpoint=endpoint)
+ info = api.model_info(model_id, revision=revision)
+ filenames = set(s.rfilename for s in info.siblings)
+
+ library_name = getattr(info, "library_name", None)
+ if any(filename.endswith(".safetensors") for filename in filenames) and not force:
+ print(f"Model {model_id} is already converted. Downloading safetensors...")
+ save_path = snapshot_download( # Download an entire directory, including the tokenizer config
+ model_id,
+ revision=revision,
+ allow_patterns=["*.safetensors", "*.json", "*.model"],
+ token=api.token,
+ cache_dir=folder,
+ endpoint=endpoint,
+ )
+ else:
+ snapshot_download( # Download an entire directory, including the tokenizer config
+ model_id,
+ revision=revision,
+ allow_patterns=["*.bin", "*.json", "*.model"],
+ token=api.token,
+ cache_dir=folder,
+ endpoint=endpoint,
+ )
+ if library_name == "transformers":
+ discard_names = get_discard_names(
+ model_id, revision=revision, folder=folder, token=api.token, endpoint=endpoint
+ )
+ if "pytorch_model.bin" in filenames:
+ save_path = convert_single(
+ model_id,
+ revision=revision,
+ folder=folder,
+ token=api.token,
+ discard_names=discard_names,
+ endpoint=endpoint,
+ )
+ elif "pytorch_model.bin.index.json" in filenames:
+ save_path = convert_multi(
+ model_id,
+ revision=revision,
+ folder=folder,
+ token=api.token,
+ discard_names=discard_names,
+ endpoint=endpoint,
+ )
+ else:
+ raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
+ else:
+ save_path = convert_generic(
+ model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, endpoint=endpoint
+ )
+ return save_path
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Downloads and converts weights to `safetensors` format.")
+ parser.add_argument(
+ "model_id",
+ type=str,
+ help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ help="The revision to convert",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ help="The output directory to download and save the converted model",
+ )
+ parser.add_argument(
+ "--endpoint",
+ type=str,
+ default="https://hf-mirror.com",
+ help="The Huggingface endpoint to use. Defaults to `https://hf-mirror.com`.",
+ )
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Force weights re-conversion.",
+ )
+ parser.add_argument(
+ "--disable_ssl_verify",
+ action="store_true",
+ help="Disable SSL verification when downloading the model weights.",
+ )
+
+ args = parser.parse_args()
+ if args.disable_ssl_verify:
+ configure_http_backend(backend_factory=backend_factory)
+
+ path = convert(
+ args.model_id, revision=args.revision, folder=args.output_dir, force=args.force, endpoint=args.endpoint
+ )
+ print(f"Converted weights saved to {os.path.dirname(os.path.dirname(path))}")
diff --git a/examples/moviegen/tools/patch_pynative.sh b/examples/moviegen/tools/patch_pynative.sh
new file mode 100644
index 0000000000..516685b0b5
--- /dev/null
+++ b/examples/moviegen/tools/patch_pynative.sh
@@ -0,0 +1,24 @@
+# Patch MindSpore to add support for recompute in PyNative mode
+
+# Find the site-packages path
+SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
+
+# Define the file path and the line to insert
+FILE_PATH="$SITE_PACKAGES/mindspore/common/recompute.py"
+LINE_AFTER=" self.wrap_cell = _WrapCell(block)"
+LINE_TO_INSERT=" self.wrap_cell.set_inputs()"
+
+# Check if the file has already been modified
+if grep -qF "$LINE_TO_INSERT" "$FILE_PATH"; then
+ echo "File $FILE_PATH has already been patched. No changes made."
+ exit 0
+fi
+
+# Use sed to insert the line after the specified pattern
+if sed -i "/$LINE_AFTER/a \\$LINE_TO_INSERT" "$FILE_PATH"
+then
+ echo "Successfully patched $FILE_PATH"
+else
+ echo "Error: Failed to patch $FILE_PATH"
+ exit 1
+fi
diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py
index a645af3da7..dbfbc4ba50 100644
--- a/examples/opensora_hpcai/scripts/train.py
+++ b/examples/opensora_hpcai/scripts/train.py
@@ -787,7 +787,7 @@ def main(args):
log_interval=args.log_interval,
start_epoch=start_epoch,
model_name=model_name,
- resume_prefix_blacklist=["vae.", "swap."],
+ resume_prefix_blacklist=("vae.", "swap."),
record_lr=False,
train_steps=args.train_steps,
)
diff --git a/examples/svd/train.py b/examples/svd/train.py
index 5e8becf50c..c13a9b942e 100644
--- a/examples/svd/train.py
+++ b/examples/svd/train.py
@@ -179,7 +179,7 @@ def main(args, initializer):
parser.add_function_arguments(
create_dataloader,
"train.dataloader",
- skip={"dataset", "transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
+ skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
)
parser.add_function_arguments(create_scheduler, "train.scheduler", skip={"steps_per_epoch", "num_epochs"})
parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"})
diff --git a/examples/t2i_adapter/train_t2i_adapter_sd.py b/examples/t2i_adapter/train_t2i_adapter_sd.py
index 41971a6603..af015534f5 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sd.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sd.py
@@ -155,7 +155,7 @@ def main(args, initializer):
parser.add_function_arguments(
create_dataloader,
"train.dataloader",
- skip={"dataset", "transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
+ skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
)
parser.add_function_arguments(build_optimizer, "train.optimizer", skip={"model"})
parser.add_class_arguments(
diff --git a/mindone/data/loader.py b/mindone/data/loader.py
index dd394f5193..ce4912a1b3 100644
--- a/mindone/data/loader.py
+++ b/mindone/data/loader.py
@@ -3,6 +3,7 @@
import mindspore as ms
from mindspore.communication import get_local_rank, get_local_rank_size
+from ..utils.version_control import MS_VERSION
from .dataset import BaseDataset
@@ -10,6 +11,7 @@ def create_dataloader(
dataset: BaseDataset,
batch_size: int = 1,
transforms: Optional[Union[List[dict], dict]] = None,
+ batch_transforms: Optional[Union[List[dict], dict]] = None,
project_columns: Optional[List[str]] = None,
shuffle: bool = False,
num_workers: int = 4,
@@ -18,7 +20,7 @@ def create_dataloader(
drop_remainder: bool = True,
python_multiprocessing: bool = True,
prefetch_size: int = 16,
- max_rowsize: int = 64,
+ max_rowsize: Optional[int] = None,
device_num: int = 1,
rank_id: int = 0,
debug: bool = False,
@@ -37,6 +39,8 @@ def create_dataloader(
"input_columns": [List of columns to apply transforms to], # Optional
"output_columns": [List of output columns] # Optional, only used if different from the `input columns`
}
+ batch_transforms: Optional transformations to apply to the dataset. Identical to `transforms` but applied to
+ batches.
project_columns: Optional list of output columns names from transformations.
These names can be used for column selection or sorting in a specific order.
shuffle: Whether to randomly sample data. Default is False.
@@ -48,8 +52,14 @@ def create_dataloader(
python_multiprocessing: Whether to use Python multiprocessing for data transformations. This option could be
beneficial if the Python operation is computational heavy. Default is True.
prefetch_size: The number of samples to prefetch (per device). Default is 16.
- max_rowsize: Maximum size of row in MB that is used for shared memory allocation to copy data between processes.
- This is only used if `python_multiprocessing` is set to `True`. Default is 64.
+ max_rowsize: Maximum size of row in MB for shared memory allocation to copy data among processes.
+ This is only used if `python_multiprocessing` is set to `True`.
+ Values:
+ - `None` (default):
+ - For MindSpore 2.3 and above: Uses -1 (dynamic allocation).
+ - For MindSpore 2.2 and below: Uses 64MB.
+ - `-1`: (MindSpore 2.3+ only) Allocates memory dynamically.
+ - Positive integer: Sets a specific maximum row size in MB.
device_num: The number of devices to distribute the dataset across. Default is 1.
rank_id: The rank ID of the current device. Default is 0.
debug: Whether to enable debug mode. Default is False.
@@ -80,8 +90,12 @@ def create_dataloader(
shuffle=shuffle,
)
+ if max_rowsize is None:
+ # MS 2.3 and above: allocate memory dynamically
+ max_rowsize = -1 if MS_VERSION >= "2.3" else 64
+
if transforms is not None:
- if not isinstance(transforms, list):
+ if isinstance(transforms, dict):
transforms = [transforms]
for transform in transforms:
@@ -108,5 +122,16 @@ def create_dataloader(
dataloader = dataloader.batch(
batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers_batch
)
+ if batch_transforms is not None:
+ if isinstance(batch_transforms, dict):
+ batch_transforms = [batch_transforms]
+
+ for batch_transform in batch_transforms:
+ dataloader = dataloader.map(
+ **batch_transform,
+ python_multiprocessing=python_multiprocessing,
+ num_parallel_workers=num_workers,
+ max_rowsize=max_rowsize,
+ )
return dataloader
diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py
index e3f77a9537..101c1a958a 100644
--- a/mindone/models/modules/parallel/__init__.py
+++ b/mindone/models/modules/parallel/__init__.py
@@ -1,7 +1,7 @@
-from mindspore import nn
+from mindspore import mint, nn
from .conv import Conv1d, Conv2d, Conv3d
-from .dense import Dense
+from .dense import Dense, Linear
# {Original MindSpore Cell: New Cell in ZeRO3}
PARALLEL_MODULES = {
@@ -9,5 +9,7 @@
nn.Conv2d: Conv2d,
nn.Conv3d: Conv3d,
nn.Dense: Dense,
+ mint.nn.Linear: Linear,
}
-__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"]
+
+__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense", "Linear"]
diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py
index 66ef7fef71..8d31690fff 100644
--- a/mindone/models/modules/parallel/dense.py
+++ b/mindone/models/modules/parallel/dense.py
@@ -1,4 +1,8 @@
-from mindspore import nn, ops
+from typing import Literal, Optional, Union
+
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import mint, nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
@@ -8,8 +12,14 @@
class Dense(nn.Cell):
- def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
- super(Dense, self).__init__(auto_prefix=False)
+ def __init__(
+ self,
+ net: Union[nn.Dense, mint.nn.Linear],
+ zero_stage: Literal[0, 1, 2, 3] = 0,
+ op_group: str = GlobalComm.WORLD_COMM_GROUP,
+ cell_type: Optional[mstype.Type] = None,
+ ):
+ super().__init__(auto_prefix=False)
self.net = net
self.set_param_wrapper(zero_stage, op_group, cell_type)
@@ -43,3 +53,8 @@ def construct(self, x):
out_shape = x_shape[:-1] + (x.shape[-1],)
x = x.reshape(out_shape)
return x
+
+
+class Linear(Dense):
+ def construct(self, x: Tensor) -> Tensor:
+ return self.net.dense(x, self.param_wrapper_w(self.net.weight), self.param_wrapper_b(self.net.bias))
diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py
index 1bdbd535e5..5e0cc15244 100755
--- a/mindone/trainers/callback.py
+++ b/mindone/trainers/callback.py
@@ -1,13 +1,14 @@
import logging
import os
import time
-from typing import List
+from typing import List, Literal, Optional, Tuple, Union
-import mindspore as ms
+from mindspore import Profiler, Tensor, nn, save_checkpoint
from mindspore.communication import get_rank
from mindspore.train.callback._callback import Callback, _handle_loss
from .checkpoint import CheckpointManager
+from .ema import EMA
from .recorder import PerfRecorder
_logger = logging.getLogger("")
@@ -23,7 +24,7 @@ def get_real_rank():
return int(os.getenv("RANK_ID", "0"))
-class OverflowMonitor(ms.Callback):
+class OverflowMonitor(Callback):
def on_train_step_end(self, run_context):
cb_params = run_context.original_args()
cur_epoch_num = cb_params.get("cur_epoch_num", 1)
@@ -38,38 +39,40 @@ def on_train_step_end(self, run_context):
class EvalSaveCallback(Callback):
def __init__(
self,
- network,
- use_lora=False,
- rank_id=0,
- ckpt_save_dir="./",
- output_dir=None,
- ema=None,
- save_ema_only=True,
- ckpt_save_policy="lastest_k",
- ckpt_max_keep=10,
- step_mode=False,
- ckpt_save_interval=1,
- use_step_unit=False,
- data_sink_mode=True,
- lora_rank=None,
- log_interval=1,
- start_epoch=0,
- record_lr=True,
- model_name="sd",
+ network: nn.Cell,
+ use_lora: bool = False,
+ rank_id: int = 0,
+ ckpt_save_dir: str = "./",
+ output_dir: str = None,
+ ema: EMA = None,
+ save_ema_only: bool = True,
+ ckpt_save_policy: Literal["top_k", "latest_k", None] = "latest_k",
+ monitor_metric: Optional[str] = None,
+ ckpt_max_keep: int = 10,
+ step_mode: bool = False,
+ ckpt_save_interval: int = 1,
+ use_step_unit: bool = False,
+ data_sink_mode: bool = True,
+ lora_rank: Optional[int] = None,
+ log_interval: int = 1,
+ start_epoch: int = 0,
+ record_lr: bool = True,
+ model_name: str = "sd",
save_trainable_only: bool = False,
param_save_filter: List[str] = None,
- resume_prefix_blacklist: List[str] = None,
- integrated_save=False,
- save_training_resume=True,
- train_steps=-1,
- prefer_low_perf=False,
+ resume_prefix_blacklist: Optional[Union[str, Tuple[str, ...]]] = None,
+ integrated_save: bool = False,
+ save_training_resume: bool = True,
+ train_steps: int = -1,
+ prefer_low_perf: bool = False,
):
"""
Args:
step_mode: if True, ckpt_save_interval is counted in steps. otherwise, in epochs.
param_save_filter: indicates what parameters to save in checkpoint. If None, save all parameters in network. \
Otherwise, only params that contain one of the keyword in param_save_filter list will be saved.
- resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint. e.g. ['swap.', 'vae.'].
+ resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint,
+ e.g. ('swap.', 'vae.').
"""
self.rank_id = rank_id
self.is_main_device = rank_id in [0, None]
@@ -96,6 +99,7 @@ def __init__(
if self.is_main_device:
self.ckpt_save_policy = ckpt_save_policy
+ self.monitor_metric = monitor_metric
self.ckpt_manager = CheckpointManager(
ckpt_save_dir,
ckpt_save_policy,
@@ -133,17 +137,11 @@ def __init__(
self.use_step_unit = use_step_unit
self.train_steps = train_steps
self.save_training_resume = save_training_resume
- if resume_prefix_blacklist is not None:
-
- def choice_func(x):
- for prefix in resume_prefix_blacklist:
- if x.startswith("vae."):
- return False
- return True
-
- self.choice_func = choice_func
- else:
- self.choice_func = None
+ self.choice_func = None
+ if resume_prefix_blacklist:
+ if isinstance(resume_prefix_blacklist, str):
+ resume_prefix_blacklist = (resume_prefix_blacklist,)
+ self.choice_func = lambda x: not x.startswith(resume_prefix_blacklist)
def on_train_step_end(self, run_context):
cb_params = run_context.original_args()
@@ -170,23 +168,27 @@ def on_train_step_end(self, run_context):
)
append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None
- if self.ema is not None:
- if not self.save_ema_only:
- self.ckpt_manager.save(
- self.net_to_save,
- None,
- ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"),
- append_dict=append_dict,
- )
- # swap ema weight and network weight
- self.ema.swap_before_eval()
-
- # save history checkpoints
- self.ckpt_manager.save(self.net_to_save, None, ckpt_name=ckpt_name, append_dict=append_dict)
+ perf = cb_params.get("eval_results")
+ if perf or self.ckpt_save_policy != "top_k":
+ if perf:
+ perf = perf[self.monitor_metric]
+ if self.ema is not None:
+ if not self.save_ema_only:
+ self.ckpt_manager.save(
+ self.net_to_save,
+ perf,
+ ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"),
+ append_dict=append_dict,
+ )
+ # swap ema weight and network weight
+ self.ema.swap_before_eval()
+
+ # save history checkpoints
+ self.ckpt_manager.save(self.net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict)
if self.save_training_resume:
# TODO: resume training for step.
- ms.save_checkpoint(
+ save_checkpoint(
cb_params.train_network,
os.path.join(self.ckpt_save_dir, "train_resume.ckpt"),
choice_func=self.choice_func,
@@ -284,7 +286,7 @@ def on_train_epoch_end(self, run_context):
)
if self.save_training_resume:
- ms.save_checkpoint(
+ save_checkpoint(
cb_params.train_network,
os.path.join(self.ckpt_save_dir, "train_resume.ckpt"),
choice_func=self.choice_func,
@@ -330,7 +332,7 @@ def _get_scaling_value_from_cbp(self, cb_params):
else:
return cb_params.train_network.scale_sense.asnumpy().item()
- def _fetch_optimizer_lr(self, cb_params) -> ms.Tensor:
+ def _fetch_optimizer_lr(self, cb_params) -> Tensor:
opt = self._get_optimizer_from_cbp(cb_params)
lr = opt.learning_rate
if opt.dynamic_lr:
@@ -338,7 +340,7 @@ def _fetch_optimizer_lr(self, cb_params) -> ms.Tensor:
return lr
-class StopAtStepCallback(ms.Callback):
+class StopAtStepCallback(Callback):
# stop the training process when reach train_steps
def __init__(self, train_steps, global_step=0):
self.global_step = global_step
@@ -350,7 +352,7 @@ def on_train_step_end(self, run_context):
run_context.request_stop()
-class ProfilerCallback(ms.Callback):
+class ProfilerCallback(Callback):
def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir="./profiler_data"):
self.start_step = start_step
self.end_step = end_step
@@ -359,7 +361,7 @@ def __init__(self, start_step=1, end_step=2, exit_after_analyze=True, out_dir=".
out_dir = os.path.join(out_dir, f"rank_{rank_id}")
# If value of profile_framework is not None, a subdirectory named host_info will be generated under the
# specified profiler directory to store the collected memory and time files on the Host side.
- self.profiler = ms.Profiler(
+ self.profiler = Profiler(
start_profile=False, output_path=out_dir, profile_framework="all", data_simplication=False
)
@@ -381,12 +383,12 @@ def on_train_step_end(self, run_context):
run_context.request_stop()
-class ProfilerCallbackEpoch(ms.Callback):
+class ProfilerCallbackEpoch(Callback):
def __init__(self, start_epoch, stop_epoch, output_dir="./profiler_data"):
super().__init__()
self.start_epoch = start_epoch
self.stop_epoch = stop_epoch
- self.profiler = ms.Profiler(start_profile=False, output_path=output_dir)
+ self.profiler = Profiler(start_profile=False, output_path=output_dir)
def on_train_epoch_begin(self, run_context):
cb_params = run_context.original_args()
diff --git a/mindone/trainers/ema.py b/mindone/trainers/ema.py
index 45fd95c2b5..ca702e5c47 100644
--- a/mindone/trainers/ema.py
+++ b/mindone/trainers/ema.py
@@ -3,6 +3,8 @@
from mindspore.ops import composite as C
from mindspore.ops import functional as F
+__all__ = ["EMA"]
+
_ema_op = C.MultitypeFuncGraph("grad_ema_op")
@@ -18,7 +20,14 @@ class EMA(nn.Cell):
offloading: if True, offload the assign computation to CPU to avoid OOM issue.
"""
- def __init__(self, network, ema_decay=0.9999, updates=0, trainable_only=True, offloading=True):
+ def __init__(
+ self,
+ network: nn.Cell,
+ ema_decay: float = 0.9999,
+ updates: int = 0,
+ trainable_only: bool = True,
+ offloading: bool = True,
+ ):
super().__init__()
# TODO: net.trainable_params() is more reasonable?
if trainable_only:
diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py
index 519347948c..db86a98a1c 100644
--- a/mindone/trainers/train_step.py
+++ b/mindone/trainers/train_step.py
@@ -1,4 +1,5 @@
"""Train step wrapper supporting setting drop overflow update, ema etc"""
+from typing import Optional
from packaging import version
@@ -13,6 +14,8 @@
from mindspore.ops import functional as F
from mindspore.ops import operations as P
+from .ema import EMA
+
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
@@ -55,7 +58,7 @@ def __init__(
network,
optimizer,
scale_sense=1.0,
- ema=None,
+ ema: Optional[EMA] = None,
updates=0,
drop_overflow_update=True,
gradient_accumulation_steps=1,
@@ -98,6 +101,10 @@ def __init__(
if gradient_accumulation_steps > 1:
self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros")
+ def set_train(self, mode: bool = True):
+ # Delegate the setting of training mode behavior to the network.
+ self.network.set_train(mode)
+
def construct(self, *inputs):
# compute loss
weights = self.weights
diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py
index fceff1b3fa..42bbc4e326 100644
--- a/mindone/trainers/zero.py
+++ b/mindone/trainers/zero.py
@@ -1,6 +1,7 @@
import json
import logging
import os
+from typing import Literal
import mindspore as ms
from mindspore import nn, ops
@@ -554,13 +555,13 @@ def prepare_train_network(
clip_grad: bool = False,
clip_norm: float = 1.0,
verbose: bool = False,
- zero_stage: int = 0,
+ zero_stage: Literal[0, 1, 2, 3] = 0,
optimizer_offload: bool = False,
op_group: str = None,
dp_group: str = None,
comm_fusion: dict = None,
parallel_modules=None,
-):
+) -> TrainOneStepWrapper:
"""
Prepare network and optimizer for distributed training.
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index b327d8709b..7d413cf204 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -1301,7 +1301,7 @@ def from_pretrained(
state_dict = kwargs.pop("state_dict", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
use_auth_token = kwargs.pop("use_auth_token", None)
diff --git a/mindone/transformers/models/t5/modeling_t5.py b/mindone/transformers/models/t5/modeling_t5.py
index 0a72dc6374..2326c3446b 100644
--- a/mindone/transformers/models/t5/modeling_t5.py
+++ b/mindone/transformers/models/t5/modeling_t5.py
@@ -1072,7 +1072,7 @@ def construct(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
) -> Union[Tuple[ms.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1099,6 +1099,7 @@ def construct(
>>> logits = outputs[1]
```"""
use_cache = use_cache if use_cache is not None else self.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
diff --git a/mindone/utils/env.py b/mindone/utils/env.py
index 76a2b3072b..4117912026 100644
--- a/mindone/utils/env.py
+++ b/mindone/utils/env.py
@@ -10,6 +10,8 @@
import mindspore as ms
from mindspore.communication import get_group_size, get_rank, init
+from .version_control import MS_VERSION
+
_logger = logging.getLogger(__name__)
@@ -18,11 +20,11 @@ def init_train_env(
device_target: Literal["Ascend", "GPU"] = "Ascend",
debug: bool = False,
seed: int = 42,
- jit_level: str = "O0",
cache_graph: bool = False,
cache_path: str = "./cache",
distributed: bool = False,
ascend_config: Optional[dict] = None,
+ jit_level: Optional[Literal["O0", "O1", "O2"]] = None,
enable_modelarts: bool = False,
max_device_memory: str = None,
) -> Tuple[int, int, int]:
@@ -40,6 +42,8 @@ def init_train_env(
cache_path: The path to save or load the saved computation graph.
distributed: Whether to enable distributed training. Default is False.
ascend_config: Parameters specific to the Ascend hardware platform.
+ jit_level: The compilation optimization level. Options: "O0", "O1", "O2".
+ Default is None and the level selected based on the device.
enable_modelarts: Whether to enable modelarts (OpenI) support. Default is False.
max_device_memory (str, default: None): The maximum amount of memory that can be allocated on the Ascend device.
@@ -48,31 +52,23 @@ def init_train_env(
"""
ms.set_seed(seed)
- if mode == ms.GRAPH_MODE:
- try:
- if jit_level in ["O0", "O1", "O2"]:
- ms.set_context(jit_config={"jit_level": jit_level})
- _logger.info(f"set jit_level: {jit_level}.")
- else:
- _logger.warning(
- f"Unsupport jit_level: {jit_level}. The framework automatically selects the execution method"
- )
- except Exception:
- _logger.warning(
- "The current jit_level is not suitable because current MindSpore version does not match,"
- "please ensure the MindSpore version >= ms2.3.0."
- )
-
if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging
_logger.warning("Debug mode is on, switching execution mode to PyNative.")
mode = ms.PYNATIVE_MODE
if max_device_memory is not None:
ms.set_context(max_device_memory=max_device_memory)
+ if jit_level:
+ if MS_VERSION >= "2.3":
+ ms.set_context(jit_config={"jit_level": jit_level})
+ else:
+ _logger.warning("Compilation optimization (JIT Level) is supported only in MindSpore 2.3 or later.")
+
if distributed:
ms.set_context(mode=mode, device_target=device_target, ascend_config=ascend_config or {})
device_id = os.getenv("DEVICE_ID", None)
if device_id:
ms.set_context(device_id=int(device_id))
+
init()
device_num = get_group_size()
rank_id = get_rank()