Zhicheng Zhang1,2# · Junyao Hu1,2# · Wentao Cheng1* · Danda Paudel3,4 · Jufeng Yang1,2
1 VCIP & TMCC & DISSec, College of Computer Science, Nankai University
2 Nankai International Advanced Research Institute (SHENZHEN · FUTIAN)
3 Computer Vision Lab, ETH Zurich 4 INSAIT, Sofia University
# Equal Contribution · * Corresponding Author
🎉 Accepted by CVPR 2024 🎉
[📃 Paper] [📃 Chinese Version] [📦 Code] [⚒️ Project] [📊 Poster] [📅 Slides] [🎞️ Bilibili / YouTube]
TL;DR: We present ExtDM, a new diffusion model that extrapolates video content from current frames by accurately modeling distribution shifts towards future frames.
- 🔥2024-06-19: The code, datasets, and model weights are being released!
- 2024-03-21: Repository created. The code is coming soon...
- 2024-02-27: ExtDM has been accepted to CVPR 2024!
conda create -n ExtDM python=3.9
conda activate ExtDM
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install einops einops_exts rotary_embedding_torch timm==0.4.5 imageio scikit-image opencv-python flow_vis matplotlib mediapy lpips h5py PyYAML tqdm wandb scipy==1.9.3
conda install ffmpeg
cd <your_path>/ExtDM
pip install -e .
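After installation, you can quickly confirm that the packages from the pip line above are importable. The sketch below uses only the standard library (`importlib.util.find_spec`). The helper name and the mapping from pip names to import names (e.g. `opencv-python` → `cv2`, `scikit-image` → `skimage`, `PyYAML` → `yaml`) are assumptions for illustration; the repo does not provide them.

```python
import importlib.util

# Import names for the packages installed above. The pip-name -> import-name
# mapping is an assumption of this sketch, not a list shipped with ExtDM.
REQUIRED_MODULES = [
    "torch", "torchvision", "einops", "einops_exts", "rotary_embedding_torch",
    "timm", "imageio", "skimage", "cv2", "flow_vis", "matplotlib", "mediapy",
    "lpips", "h5py", "yaml", "tqdm", "wandb", "scipy",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Usage (inside the activated ExtDM environment):
#   print(missing_modules(REQUIRED_MODULES) or "all packages found")
```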
Note: If you encounter the following error when running the code:
/home/<user_name>/anaconda3/envs/ExtDM/bin/ffmpeg: error while loading shared libraries: libopenh264.so.5: cannot open shared object file: No such file or directory
you can fix it by copying the available libopenh264.so.6 to the filename ffmpeg expects:
cp /home/<user_name>/anaconda3/envs/ExtDM/lib/libopenh264.so.6 /home/<user_name>/anaconda3/envs/ExtDM/lib/libopenh264.so.5
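The same workaround can be scripted so that it is safe to run more than once. This is a minimal sketch, assuming that the environment's shared libraries live in `$CONDA_PREFIX/lib` (this matches the path in the error message for a standard Anaconda install). The function name is hypothetical.

```python
import shutil
from pathlib import Path


def ensure_openh264_alias(lib_dir, have="libopenh264.so.6", want="libopenh264.so.5"):
    """Copy `have` to `want` inside lib_dir if `want` is missing.

    Returns True if a copy was made, False if `want` already existed.
    """
    lib_dir = Path(lib_dir)
    src, dst = lib_dir / have, lib_dir / want
    if dst.exists():
        return False  # nothing to do; ffmpeg should already find the library
    if not src.exists():
        raise FileNotFoundError(f"{src} not found; check the conda env's lib directory")
    shutil.copy2(src, dst)  # same effect as the cp command above
    return True
```

With the ExtDM environment activated, a call such as `ensure_openh264_alias(Path(os.environ["CONDA_PREFIX"]) / "lib")` performs the copy. `conda activate` sets `CONDA_PREFIX` to the environment root.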
Overview of the preprocessed datasets (c -> p: number of context frames -> number of predicted frames):
Dataset | # Videos (train) | # Videos (test) | Avg. Frames per Video (train) | Setup (c -> p) | Link & Size |
---|---|---|---|---|---|
SMMNIST | 60000 | 256 | 40 | 10 -> 10 | google drive (688M) |
KTH | 479 | 120 (sampled to 256) | 483.18 | 10 -> 30/40 | google drive (919M) |
BAIR | 43264 | 256 | 30 | 2 -> 14/28 | google drive (13G) |
Cityscapes | 2975 | 1525 (sampled to 256) | 30 | 2 -> 28 | google drive (1.3G) |
UCF-101 | - | - | - | 4 -> 12 | google drive (40G) |
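The setup column (c -> p) works as follows: each evaluation clip is split into c conditioning frames and the p frames that follow, which the model must predict. The sketch below illustrates this on any sequence of frames. The names `SETUPS` and `split_clip` are hypothetical and do not come from the repo.

```python
# (context frames, list of prediction lengths) per dataset, copied from the table above.
SETUPS = {
    "SMMNIST": (10, [10]),
    "KTH": (10, [30, 40]),
    "BAIR": (2, [14, 28]),
    "Cityscapes": (2, [28]),
    "UCF-101": (4, [12]),
}


def split_clip(frames, num_context, num_pred):
    """Split a frame sequence into (context frames, target frames to predict)."""
    needed = num_context + num_pred
    if len(frames) < needed:
        raise ValueError(f"clip has {len(frames)} frames, but the setup needs {needed}")
    return frames[:num_context], frames[num_context:needed]


# A 30-frame clip under the 2 -> 28 setup: context == [0, 1], target == [2, ..., 29].
context, target = split_clip(list(range(30)), num_context=2, num_pred=28)
```

BAIR and Cityscapes videos have 30 frames, so their longest setup (2 -> 28) uses the whole clip. KTH videos are much longer (about 483 frames on average), which leaves room to sample 50-frame windows for the 10 -> 40 setup.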
SMMNIST: the script below automatically downloads the MNIST dataset through PyTorch and uses it to dynamically generate randomly moving MNIST sequences, which it saves in HDF5 format.
How the SMMNIST data was processed:
cd <your_path>/ExtDM/data/SMMNIST
dataset_root=<your_data_path>/SMMNIST
python 01_mnist_download_and_convert.py --image_size 64 --mnist_dir $dataset_root --out_dir $dataset_root/processed --force_h5 False
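The generation step works roughly like this. A 28×28 MNIST digit is pasted onto a blank 64×64 canvas. Its position moves at a constant velocity and bounces off the borders with a newly sampled speed, as in the commonly used Stochastic Moving MNIST recipe. This is a minimal, dependency-free sketch and not the code in `01_mnist_download_and_convert.py`. The step sizes (1 to 3 pixels per frame) and the helper names are assumptions.

```python
import random


def bouncing_trajectory(num_frames, canvas=64, digit=28, max_step=3, seed=None):
    """Top-left (x, y) positions of one digit moving inside a canvas x canvas frame.

    Assumes canvas - digit >= 2 * max_step - 1 so a bounce never leaves the canvas.
    """
    rng = random.Random(seed)
    limit = canvas - digit  # largest top-left coordinate that keeps the digit fully visible
    x, y = rng.randint(0, limit), rng.randint(0, limit)
    dx = rng.choice([-1, 1]) * rng.randint(1, max_step)
    dy = rng.choice([-1, 1]) * rng.randint(1, max_step)
    positions = []
    for _ in range(num_frames):
        positions.append((x, y))
        # On hitting a border, reverse direction and draw a new random speed.
        if not 0 <= x + dx <= limit:
            dx = (-1 if dx > 0 else 1) * rng.randint(1, max_step)
        if not 0 <= y + dy <= limit:
            dy = (-1 if dy > 0 else 1) * rng.randint(1, max_step)
        x, y = x + dx, y + dy
    return positions


def render(positions, digit_img, canvas=64):
    """Paste one digit (a digit x digit 2-D list of pixel values) onto a blank canvas per position."""
    frames = []
    for x, y in positions:
        frame = [[0] * canvas for _ in range(canvas)]  # indexed as frame[y][x]
        for i, row in enumerate(digit_img):
            frame[y + i][x:x + len(row)] = row
        frames.append(frame)
    return frames
```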
How the KTH data was processed:
cd <your_path>/ExtDM/data/KTH
dataset_root=<your_data_path>/KTH
sh 01_kth_download.sh $dataset_root
python 02_kth_train_val_test_split.py
python 03_kth_convert.py --split_dir ./ --image_size 64 --kth_dir $dataset_root/raw --out_dir $dataset_root/mixed_processed --force_h5 False
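The table lists 479 training and 120 test videos for KTH. These counts are consistent with a split by actor: persons 01-20 for training and 21-25 for testing (5 persons × 6 actions × 4 scenarios = 120 clips). The sketch below shows such a person-based split over KTH's `personXX_action_dY_uncomp.avi` file names. The cut-off at person 20 is inferred from the table counts. What `02_kth_train_val_test_split.py` actually does is an assumption here.

```python
import re

# KTH file names look like "person01_boxing_d1_uncomp.avi".
KTH_NAME = re.compile(r"person(\d+)_([a-z]+)_d(\d)")


def kth_split(filenames, last_train_person=20):
    """Split KTH videos by actor: persons <= last_train_person go to train, the rest to test."""
    train, test = [], []
    for name in filenames:
        match = KTH_NAME.match(name)
        if match is None:
            continue  # skip files that do not follow the KTH naming scheme
        person = int(match.group(1))
        (train if person <= last_train_person else test).append(name)
    return train, test
```

On a complete file list, persons 01-20 yield 20 × 24 = 480 candidate clips. The 479 in the table suggests that one clip is absent from the downloaded data.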
How the BAIR data was processed:
cd <your_path>/ExtDM/data/BAIR
dataset_root=<your_data_path>/BAIR
sh 01_bair_download.sh $dataset_root
python bair_convert.py --bair_dir $dataset_root/raw --out_dir $dataset_root/processed
How the Cityscapes data was processed:
MAKE SURE YOU HAVE ~657GB OF FREE SPACE! 324GB for the zip file, and 333GB for the unzipped image files.
- Download the Cityscapes video dataset (`leftImg8bit_sequence_trainvaltest.zip`, 324GB):
sh cityscapes_download.sh username password
using the `username` and `password` that you created on https://www.cityscapes-dataset.com/
- Convert it to HDF5 format and save it in `/path/to/Cityscapes<image_size>_h5`:
python datasets/cityscapes_convert.py --leftImg8bit_sequence_dir '/path/to/Cityscapes/leftImg8bit_sequence' --image_size 64 --out_dir '/path/to/Cityscapes64_h5'
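Because these downloads are large, it is worth checking free space before you start. This is a minimal standard-library sketch (the helper name is hypothetical). It treats "GB" as 10**9 bytes, because the README does not say which unit it means.

```python
import shutil

# Free-space requirements quoted in this README, in GB.
REQUIRED_GB = {"Cityscapes": 657, "UCF-101": 20}


def has_enough_space(path, required_gb):
    """Return (ok, free_gb) for the filesystem containing `path` (which must exist)."""
    free_gb = shutil.disk_usage(path).free / 10**9
    return free_gb >= required_gb, free_gb
```

For example, `ok, free = has_enough_space("/path/to/Cityscapes", REQUIRED_GB["Cityscapes"])` should be run on an existing directory on the target disk before you launch `cityscapes_download.sh`.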
How the UCF-101 data was processed:
MAKE SURE YOU HAVE ~20GB OF FREE SPACE! 6.5GB for the rar file, and 8GB for the extracted videos.
- Download the UCF-101 video dataset (`UCF101.rar`, 6.5GB):
sh ucf101_download.sh /download/dir
- Convert it to HDF5 format and save it in `/path/to/UCF101_h5`:
python datasets/ucf101_convert.py --out_dir /path/to/UCF101_h5 --ucf_dir /download/dir/UCF-101 --splits_dir /download/dir/ucfTrainTestlist
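The `--splits_dir` argument points to the official `ucfTrainTestlist` folder. Its text files list one video per line. In `trainlist01.txt`, a line looks like `ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 1`, where the trailing number is a 1-based class index. In `testlist01.txt`, a line looks like `ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi`. A minimal parser is sketched below. The function name is hypothetical, and `ucf101_convert.py` may read these files differently.

```python
from pathlib import Path


def read_ucf_split(list_file):
    """Parse a UCF-101 trainlistXX.txt / testlistXX.txt into (class_name, relative_path) pairs."""
    items = []
    for line in Path(list_file).read_text().splitlines():
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        rel_path = fields[0]  # train lists add a class index as a second field; test lists do not
        class_name = rel_path.split("/")[0]
        items.append((class_name, rel_path))
    return items
```

For example, `read_ucf_split("/download/dir/ucfTrainTestlist/trainlist01.txt")` returns the training videos of split 1, grouped by the class folder in their path.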
AE (Autoencoder) Training
- Check `./config/AE/[DATASET].yaml` and set proper values for `root_dir`, `num_regions`, `max_epochs`, `num_repeats`, `lr`, `batch_size`, etc.
- Run `sh ./scripts/AE/train_AE_[DATASET].sh`, e.g.:
sh ./scripts/AE/train_AE_smmnist.sh
sh ./scripts/AE/train_AE_kth.sh
sh ./scripts/AE/train_AE_bair.sh
sh ./scripts/AE/train_AE_cityscapes.sh
sh ./scripts/AE/train_AE_ucf.sh
- You can find the directory of your running experiment in `./logs_training/AE/[DATASET]/[EXP_NAME]`, or see details in the wandb panels.
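When you tune `max_epochs`, `num_repeats` and `batch_size`, it helps to estimate how many optimizer steps a run will take. The sketch below assumes that `num_repeats` is a dataset-repeat factor: each epoch visits every training video `num_repeats` times. This is how the option is commonly used in first-order-motion-model-style training code, but check the repo's dataloader to confirm. The helper name and the settings in the example line are hypothetical.

```python
import math


def training_steps(num_videos, num_repeats, batch_size, max_epochs, drop_last=True):
    """Estimate (steps per epoch, total steps), assuming each epoch repeats the dataset num_repeats times."""
    samples_per_epoch = num_videos * num_repeats
    if drop_last:
        steps_per_epoch = samples_per_epoch // batch_size  # last incomplete batch is dropped
    else:
        steps_per_epoch = math.ceil(samples_per_epoch / batch_size)
    return steps_per_epoch, steps_per_epoch * max_epochs


# Hypothetical KTH settings: 479 training videos (from the table), num_repeats=10, batch_size=32, max_epochs=100.
steps, total = training_steps(479, 10, 32, 100)  # -> (149, 14900)
```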
AE Inference
- Run `sh ./scripts/AE/valid_AE_[DATASET].sh`, e.g.:
sh ./scripts/AE/valid_AE_smmnist.sh
sh ./scripts/AE/valid_AE_kth.sh
sh ./scripts/AE/valid_AE_bair.sh
sh ./scripts/AE/valid_AE_cityscapes.sh
sh ./scripts/AE/valid_AE_ucf.sh
- You can find the directory of your running experiment in `./logs_validation/AE/[DATASET]/[EXP_NAME]`.
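Each run writes to its own `[EXP_NAME]` folder. When several runs exist, a small helper can pick the most recently modified one. This is a standard-library sketch that assumes the `./logs_*/[STAGE]/[DATASET]/[EXP_NAME]` layout described above. The function name is hypothetical, and the exact `[DATASET]` folder name depends on your config.

```python
from pathlib import Path


def latest_experiment(log_root, stage, dataset):
    """Return the most recently modified run directory under log_root/stage/dataset, or None."""
    base = Path(log_root) / stage / dataset
    if not base.is_dir():
        return None
    runs = [p for p in base.iterdir() if p.is_dir()]
    return max(runs, key=lambda p: p.stat().st_mtime, default=None)


# Usage (hypothetical dataset folder name): latest_experiment("./logs_validation", "AE", "kth")
```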
DM (Diffusion Model) Training
- Check `./config/DM/[DATASET].yaml` and set proper values for `root_dir`, `max_epochs`, `num_repeats`, `lr`, `batch_size`, etc.
- Run `sh ./scripts/DM/train_DM_[DATASET].sh`, e.g.:
sh ./scripts/DM/train_DM_smmnist.sh
sh ./scripts/DM/train_DM_kth.sh
sh ./scripts/DM/train_DM_bair.sh
sh ./scripts/DM/train_DM_cityscapes.sh
sh ./scripts/DM/train_DM_ucf.sh
- You can find the directory of your running experiment in `./logs_training/DM/[DATASET]/[EXP_NAME]`, or see details in the wandb panels.
DM Inference
- Run `sh ./scripts/DM/valid_DM_[DATASET].sh`, e.g.:
sh ./scripts/DM/valid_DM_smmnist.sh
sh ./scripts/DM/valid_DM_kth.sh
sh ./scripts/DM/valid_DM_bair.sh
sh ./scripts/DM/valid_DM_cityscapes.sh
sh ./scripts/DM/valid_DM_ucf.sh
- You can find the directory of your running experiment in `./logs_validation/DM/[DATASET]/[EXP_NAME]`.
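The four steps above follow a fixed naming pattern (`{train,valid}_{AE,DM}_[DATASET].sh`), so a single driver can chain them. The sketch below is hypothetical: `pipeline_commands` and `run_pipeline` are not part of the repo. It defaults to a dry run that only prints the commands. Before you let it execute the steps back to back, check `./config/AE/[DATASET].yaml` and `./config/DM/[DATASET].yaml`.

```python
import subprocess

DATASETS = ("smmnist", "kth", "bair", "cityscapes", "ucf")
# Order of the steps described above: AE training and inference, then DM training and inference.
STEPS = (("AE", "train"), ("AE", "valid"), ("DM", "train"), ("DM", "valid"))


def pipeline_commands(dataset):
    """Build the shell commands for all four steps on one dataset."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    return [["sh", f"./scripts/{stage}/{mode}_{stage}_{dataset}.sh"] for stage, mode in STEPS]


def run_pipeline(dataset, dry_run=True):
    """Print each command; execute it only when dry_run is False, stopping at the first failure."""
    for cmd in pipeline_commands(dataset):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a script fails
```

For example, `run_pipeline("kth")` prints the four KTH commands, and `run_pipeline("kth", dry_run=False)` runs them in order from the repository root.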
If you have any questions, please feel free to contact:
- Zhicheng Zhang: [email protected]
- Junyao Hu: [email protected]
If you find this project useful, please consider citing:
@inproceedings{zhang2024ExtDM,
title={ExtDM: Distribution Extrapolation Diffusion Model for Video Prediction},
author={Zhang, Zhicheng and Hu, Junyao and Cheng, Wentao and Paudel, Danda and Yang, Jufeng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2024}
}
This code borrows from CVPR23_LFDM (by @nihaomiao). The datasets partly come from mcvd-pytorch (by @voletiv).