Add MovieGen - TAE (#778)
* add tae

* linting

* rm redundancy

* fix inflate

* add dev plan

* Update README.md

* update psnr w/ opl

* add report

* Update report.md

* Update report.md

* Update report.md

* Update report.md

* Update report.md

* fix linting

* add transformer to tech report

* Update report.md

* linting

* rename to moviegen

* rm irrelevant files

* add report link

* use mint.pad

* linting
SamitHuang authored Jan 2, 2025
1 parent 72ba26e commit e129159
Showing 25 changed files with 5,554 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/animatediff/ad/models/diffusion/ddpm.py
@@ -365,7 +365,10 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
"""

# 1. get image/video latents z using vae
z = self.get_latents(x)
if self.emb_cache:
z = x
else:
z = self.get_latents(x)

# 2. sample timestep and add noise to latents
t = self.uniform_int(
151 changes: 151 additions & 0 deletions examples/moviegen/README.md
@@ -0,0 +1,151 @@
# Movie Gen Video based on MindSpore

This project is based on the [Movie Gen](https://arxiv.org/abs/2410.13720) paper by Meta, which covers video generation, personalization, and editing. We aim to explore an efficient implementation based on MindSpore and Ascend NPUs. See our [report](docs/report.md) for more details.

## 📑 Development Plan

This project is in an early stage and under active development. We welcome the open-source community to contribute to this project!

- Temporal Autoencoder (TAE)
    - [x] Inference
    - [x] Training
- MovieGenVideo-5B (T2I/V)
    - [x] Inference
    - [x] Training stage 1: T2I 256px
    - [x] Training stage 2: T2I/V 256px 256 frames
    - [ ] Training stage 3: T2I/V 768px 256 frames (under training)
    - [x] Web Demo (Gradio)
- MovieGenVideo-30B (T2I/V)
    - [x] Inference
    - [ ] Mixed parallelism training (supports DP+SP+CP+TP+MP+ZeRO-3, under training)
- Personalized-MovieGenVideo (PT2V)
    - [ ] Inference
    - [ ] Training
- MovieGen-Edit
    - [ ] Inference
    - [ ] Training


## Temporal Autoencoder (TAE)


### Requirements

| mindspore | ascend driver | firmware | cann toolkit/kernel |
|:----------:|:--------------:|:-----------:|:------------------:|
| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 |

### Prepare weights

We use the SD3.5 VAE to initialize the spatial layers of TAE, since both have the same number of latent channels, i.e., 16.

1. Download SD3.5 VAE from [huggingface](https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae)

2. Inflate the VAE checkpoint for TAE initialization by running:

```shell
python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt
```
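
Conceptually, the inflation step copies the 2D spatial weights from the SD3.5 VAE into a TAE-compatible checkpoint, while the TAE-only temporal layers keep their fresh initialization. Below is a minimal sketch of this idea; the key handling shown here is an assumption for illustration, not the actual logic of `inflate_vae_to_tae.py`.

```python
# Sketch only: the real inflate_vae_to_tae.py may rename or filter keys differently.
import mindspore as ms
from safetensors.numpy import load_file

# load the 2D VAE weights exported by diffusers (safetensors -> numpy arrays)
sd_vae = load_file("/path/to/sd3.5_vae/diffusion_pytorch_model.safetensors")

# wrap every spatial parameter into the MindSpore checkpoint format;
# temporal layers added by TAE are simply absent here and keep their init
params = [{"name": name, "data": ms.Tensor(value)} for name, value in sd_vae.items()]

ms.save_checkpoint(params, "models/tae_vae2d.ckpt")
```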

### Prepare datasets

We need to prepare a CSV annotation file listing the path of each input video relative to the root folder, which is indicated by the `video_folder` argument. An example:
```
video
dance/vid001.mp4
dance/vid002.mp4
dance/vid003.mp4
...
```

Taking UCF-101 as an example, please download the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extract it to the `datasets/UCF-101` folder.
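
The annotation file can be generated with a few lines of Python. The snippet below is a small helper (not part of this repo) that scans the extracted `datasets/UCF-101` folder and writes the relative paths into `video_train.csv`; adjust the file extension and any train/test split logic as needed.

```python
import csv
from pathlib import Path

video_root = Path("datasets/UCF-101")
with open("video_train.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["video"])  # header column expected in the annotation file
    for path in sorted(video_root.rglob("*.avi")):  # UCF-101 clips are .avi files
        writer.writerow([path.relative_to(video_root).as_posix()])
```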


### Training

TAE is trained to optimize the reconstruction loss, the perceptual loss, and the outlier penalty loss (OPL) proposed in the Movie Gen paper.
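
As described in the paper, OPL penalizes latent values that fall far outside the latent distribution. Below is a rough numpy sketch of that idea for reference only; the threshold `r` and the exact reduction are assumptions and may differ from the implementation used here.

```python
import numpy as np

def outlier_penalty_loss(z: np.ndarray, r: float = 3.0) -> float:
    """Penalize latent values deviating from the mean by more than r std devs
    (sketch of the OPL idea; r=3 and the mean reduction are assumptions)."""
    mean, std = z.mean(), z.std()
    excess = np.abs(z - mean) - r * std  # how far each value exceeds the allowed band
    return float(np.maximum(excess, 0.0).mean())

# example: a dummy latent of shape (b, c, t, h, w) with 16 latent channels
z = np.random.randn(1, 16, 4, 32, 32).astype(np.float32)
print(outlier_penalty_loss(z))
```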

To launch training, please run

```shell
python scripts/train_tae.py \
    --config configs/tae/train/mixed_256x256x32.yaml \
    --output_path /path/to/save_ckpt_and_log \
    --csv_path /path/to/video_train.csv \
    --video_folder /path/to/video_root_folder
```

Unlike the paper, we found that the OPL loss does not benefit the training outcome in our ablation study (PSNR with OPL is 31.17), so we disable it by default. You may enable it by appending `--use_outlier_penalty_loss True` to the command above.

For more details on the arguments, please run `python scripts/train_tae.py --help`.


### Evaluation

To run video reconstruction with the trained TAE model and evaluate the PSNR and SSIM on the test set, please run

```shell
python scripts/inference_tae.py \
    --ckpt_path /path/to/tae.ckpt \
    --batch_size 2 \
    --num_frames 32 \
    --image_size 256 \
    --csv_path /path/to/video_test.csv \
    --video_folder /path/to/video_root_folder
```

The reconstructed videos will be saved in `samples/recons`.
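
For a quick sanity check of the reported metrics, PSNR between an original frame and its reconstruction can be computed as below. This is a minimal numpy reference; the evaluation script's exact implementation may differ.

```python
import numpy as np

def psnr(x: np.ndarray, y: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two uint8 frames (or whole videos)."""
    mse = np.mean((x.astype(np.float64) - y.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)

# example: compare an original frame with a slightly perturbed reconstruction
orig = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
recon = np.clip(orig.astype(np.int16) + np.random.randint(-2, 3, orig.shape), 0, 255).astype(np.uint8)
print(f"PSNR: {psnr(orig, recon):.2f} dB")
```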

#### Performance

Here, we report the training performance and evaluation results on the UCF-101 dataset.

Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.

| model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe |
| :--: | :---: | :--: | :--: | :--: | :--: | :--: |:--: | :--: |:--: |:--: |
| TAE | 1 | 1 | 256x256x32 | bf16 | O0 | 2 mins | 2.18 | 31.35 | 0.92 | [config](configs/tae/train/mixed_256x256x32.yaml) |


### Usage for Latent Diffusion Models

<details>
<summary>View more</summary>

#### Encoding video

```python
from mg.models.tae.tae import TemporalAutoencoder, TAE_CONFIG

# may set use_tile=True to save memory
tae = TemporalAutoencoder(
    pretrained='/path/to/tae.ckpt',
    use_tile=False,
)

# x - a batch of videos, shape (b c t h w)
z, _, _ = tae.encode(x)


# you may scale z by:
# z = TAE_CONFIG['scaling_factor'] * (z - TAE_CONFIG['shift_factor'])

```

For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py).

#### Decoding video latents

```python
# if z is scaled, you should unscale it first:
# z = z / TAE_CONFIG['scaling_factor'] + TAE_CONFIG['shift_factor']

# z - a batch of video latent, shape (b c t h w)
x = tae.decode(z)

# for image decoding, set num_target_frames to discard the spurious frames
x = tae.decode(z, num_target_frames=1)
```

</details>
46 changes: 46 additions & 0 deletions examples/moviegen/configs/tae/train/mixed_256x256x16.yaml
@@ -0,0 +1,46 @@
# model
pretrained_model_path: "models/tae_vae2d.ckpt"

# loss
perceptual_loss_weight: 1.0
kl_loss_weight: 1.e-6
use_outlier_penalty_loss: False # OPL brings no benefit in our experiments
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 16
image_size: 256
crop_size: 256
# flip: True

# training recipe
seed: 42
use_discriminator: False
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.

dtype: "fp32"
amp_level: "O0"
use_recompute: False

epochs: 2000
ckpt_save_interval: 50
init_loss_scale: 1024.
loss_scaler_type: dynamic

scheduler: "constant"
use_ema: False

output_path: "outputs/tae_train"

# ms setting
jit_level: O0
46 changes: 46 additions & 0 deletions examples/moviegen/configs/tae/train/mixed_256x256x32.yaml
@@ -0,0 +1,46 @@
# model
pretrained_model_path: "models/tae_vae2d.ckpt"

# loss
perceptual_loss_weight: 1.0
kl_loss_weight: 1.e-6
use_outlier_penalty_loss: False # OPL brings no benefit in our experiments
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 32
image_size: 256
crop_size: 256
# flip: True

# training recipe
seed: 42
use_discriminator: False
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.

dtype: "bf16"
amp_level: "O2" # reduce memory cost
use_recompute: True

epochs: 2000
ckpt_save_interval: 50
init_loss_scale: 1024.
loss_scaler_type: dynamic

scheduler: "constant"
use_ema: False

output_path: "outputs/tae_train"

# ms setting
jit_level: O0
