Add Movie Gen #783

Open
wants to merge 131 commits into base: master
Changes from 119 commits
Commits (131)
111200d
fix config names
SamitHuang Aug 6, 2024
94614d9
add mem monitor
SamitHuang Aug 16, 2024
f784d36
update
SamitHuang Aug 19, 2024
0586212
update
SamitHuang Aug 19, 2024
fab1870
Merge branch 'master' of github.com:SamitHuang/mindone
SamitHuang Oct 14, 2024
828531d
Merge branch 'master' of github.com:SamitHuang/mindone
SamitHuang Oct 15, 2024
929b7ae
debug tae attn
SamitHuang Oct 22, 2024
a50a1cf
update
SamitHuang Oct 25, 2024
3676a9f
useless change
SamitHuang Oct 25, 2024
a5655fb
continue work from llama3_movie_pr_20241029
zhtmike Oct 29, 2024
5b77353
add parallel test case for scheduler and fix some minor bug
zhtmike Oct 29, 2024
c9bb319
add train script
zhtmike Oct 29, 2024
bb01342
move config file outside the folder
zhtmike Oct 29, 2024
3d279be
temp save
SamitHuang Oct 30, 2024
2cc907c
change some ops to mint
zhtmike Oct 30, 2024
888a213
add init for text projector
zhtmike Oct 30, 2024
d5a6406
fix mint
zhtmike Oct 31, 2024
14f0a34
fix type
zhtmike Oct 31, 2024
231be44
encoder ok
SamitHuang Nov 4, 2024
539437a
add image support to OS data loader
hadipash Oct 30, 2024
da5576b
update convert script
hadipash Oct 29, 2024
ee2acef
add recompute support in PyNative
hadipash Oct 29, 2024
59ed0d5
add dataloader
hadipash Oct 29, 2024
2fe103b
update train script
hadipash Oct 30, 2024
52822b0
add OSv1.2 VAE
hadipash Nov 4, 2024
d3ae9e3
fixes
hadipash Nov 5, 2024
e7a5523
Merge pull request #4 from hadipash/movie_gen
zhtmike Nov 5, 2024
798698c
reconstruct tested
SamitHuang Nov 5, 2024
8648df1
update readme
SamitHuang Nov 5, 2024
a6b5a49
discard spurious frames
SamitHuang Nov 6, 2024
3f16672
rename
SamitHuang Nov 6, 2024
78afc63
Merge branch 'tae' of github.com:SamitHuang/mindone into tae
SamitHuang Nov 6, 2024
df2f01c
add train
SamitHuang Nov 6, 2024
69865d4
add train config
SamitHuang Nov 6, 2024
7bfba90
rename
SamitHuang Nov 6, 2024
3db54f0
rename
SamitHuang Nov 6, 2024
5c2913f
add dataset
SamitHuang Nov 6, 2024
b8e1129
Merge branch 'tae' of github.com:SamitHuang/mindone into tae
SamitHuang Nov 6, 2024
4b706fb
trainable
SamitHuang Nov 7, 2024
9557e59
add inference
hadipash Nov 5, 2024
602d00e
fix opl loss
SamitHuang Nov 7, 2024
612efba
z 16
SamitHuang Nov 8, 2024
5414b0e
fix linear-quadratic sampling
hadipash Nov 11, 2024
89c3dbc
text encoders inference
hadipash Nov 11, 2024
90359e9
allow loading sd3.5 vae pretrained weights
SamitHuang Nov 13, 2024
9ebca22
Merge branch 'tae' of github.com:SamitHuang/mindone into tae
SamitHuang Nov 13, 2024
0f75248
update convert script
SamitHuang Nov 13, 2024
82ed8f7
add sd3 vae
SamitHuang Nov 13, 2024
9c0512f
add moduels for sd3 vae
SamitHuang Nov 13, 2024
a671816
update configs
hadipash Nov 13, 2024
835f9f0
temporal median init, 1p train psnr ok
SamitHuang Nov 15, 2024
5710f42
add files
SamitHuang Nov 15, 2024
164f0c9
fix rt id
SamitHuang Nov 15, 2024
4f7e3ef
set image and crop size
SamitHuang Nov 16, 2024
d830f5a
add train step mode
hadipash Nov 14, 2024
9372beb
replace interpolate for bf16 support
SamitHuang Nov 19, 2024
3eb6cb1
add validation support
hadipash Nov 14, 2024
1410f37
add ReduceLROnPlateau
hadipash Nov 19, 2024
625ee0d
save top K checkpoints
hadipash Nov 19, 2024
bf6988e
add drop text conditioning for training
hadipash Nov 20, 2024
1a2bfb7
fix eval loss calculation
hadipash Nov 20, 2024
04e11f3
add model parallel
hadipash Nov 21, 2024
8af7437
hack for model parallel
zhtmike Oct 29, 2024
bf505d4
fix hack
hadipash Nov 21, 2024
a048efb
small fixes
hadipash Nov 21, 2024
03e4271
add temporal tile
SamitHuang Nov 22, 2024
640871a
rm comments
SamitHuang Nov 22, 2024
1cd636e
clean code
SamitHuang Nov 22, 2024
cae87f5
draft readme and update decode
SamitHuang Nov 22, 2024
f64fa00
add config
SamitHuang Nov 22, 2024
961514b
add readme draft
SamitHuang Nov 22, 2024
e528c82
Merge remote-tracking branch 'Yongxiang/tae' into movie_gen
hadipash Nov 25, 2024
f083620
add TAE to Movie Gen
hadipash Nov 25, 2024
7bdac42
add buckets and dynamic graph support
hadipash Nov 25, 2024
37346d0
fix dynamic shape: defualt manual pad for conv1d same pad
SamitHuang Nov 26, 2024
ef3fa08
fix save callback and TAE scaling
hadipash Nov 25, 2024
0054700
Revert "fix hack"
hadipash Nov 26, 2024
28349a5
Revert "hack for model parallel"
hadipash Nov 26, 2024
855358e
Merge remote-tracking branch 'Yongxiang/tae' into movie_gen
hadipash Nov 26, 2024
84a25dd
revert it later
hadipash Nov 27, 2024
a260a1e
small fixes
hadipash Nov 27, 2024
ef72175
refactoring
hadipash Nov 28, 2024
5ef55ef
linting
hadipash Nov 28, 2024
ccf943a
add docs
hadipash Nov 28, 2024
3f7d207
refactor TAE
hadipash Dec 3, 2024
2b0673c
fix training with TAE latents
hadipash Dec 9, 2024
d763d6b
revert changes to OpenSora
hadipash Dec 10, 2024
2db9e2f
Merge branch 'master' into movie_gen
hadipash Dec 10, 2024
0911832
merge with PR #778
hadipash Dec 11, 2024
657628e
small fix
hadipash Dec 11, 2024
82cff9e
PR fixes:
hadipash Dec 11, 2024
a7d8a37
Update docs
hadipash Dec 11, 2024
0b77dfe
Update docs
hadipash Dec 11, 2024
e69f04d
update docs and small fixes
hadipash Dec 12, 2024
db10514
fix TAE encoding
hadipash Dec 12, 2024
aa60321
PR fixes:
hadipash Dec 13, 2024
ce10e7f
small inference fix
hadipash Dec 13, 2024
3076725
enable `lazy_inline`
hadipash Dec 16, 2024
8ba45df
small fix
hadipash Dec 17, 2024
8a9abfa
small fix
hadipash Dec 17, 2024
0392703
enable flexible recompute
hadipash Dec 18, 2024
3c33e66
enable flexible recompute
hadipash Dec 18, 2024
dc711f9
- add train resume feature
hadipash Dec 18, 2024
0776fdc
ResizeCrop fix
hadipash Dec 19, 2024
333faa3
update docs
hadipash Dec 19, 2024
d237f8f
support SP and change rms to ops.rms
zhtmike Dec 19, 2024
302df95
Gradio demo for MovieGen (#6)
itruonghai Dec 20, 2024
fcf7abc
Merge branch 'master' into movie_gen
hadipash Dec 20, 2024
7ee8118
update docs and add stage 3 configs
hadipash Dec 20, 2024
06ce9b4
add ZeRO-3 support to Movie Gen
hadipash Dec 24, 2024
fbe4e31
add Model Parallel
hadipash Dec 24, 2024
5aa1e4d
add technical report
hadipash Dec 24, 2024
de047db
update technical report
hadipash Dec 24, 2024
951a9d7
linting
hadipash Dec 24, 2024
88b5051
add inference without TAE and stand-alone decoding
hadipash Dec 24, 2024
f834cb4
Drop Model Parallel
hadipash Jan 2, 2025
b37fb7a
Merge remote-tracking branch 'Mike/movie_gen' into movie_gen
hadipash Jan 2, 2025
e0396c3
improve SP support
hadipash Jan 2, 2025
6b88764
fix checkpoint saving
hadipash Jan 2, 2025
b10da01
fix checkpoint saving
hadipash Jan 3, 2025
4881d03
align with PR#778
hadipash Jan 3, 2025
10c81e3
update README
zhtmike Jan 3, 2025
08d37c9
Merge remote-tracking branch 'origin/master' into movie_gen
hadipash Jan 3, 2025
fdd7045
Merge remote-tracking branch 'Mike/movie_gen' into movie_gen
hadipash Jan 3, 2025
7fa547e
resolve comments
hadipash Jan 6, 2025
fd6428f
fix SP
hadipash Jan 6, 2025
96a614d
small fixes and update README.md
hadipash Jan 6, 2025
9c43f9e
add TAE download link
hadipash Jan 6, 2025
8b97808
update README.md
hadipash Jan 6, 2025
d436fbe
fix imports and update README.md
hadipash Jan 6, 2025
6b364d3
update README.md
hadipash Jan 6, 2025
309 changes: 309 additions & 0 deletions examples/moviegen/README.md

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions examples/moviegen/args_train_tae.py
@@ -0,0 +1,182 @@
import logging
import os
import sys

from jsonargparse import ActionConfigFile, ArgumentParser

# TODO: remove in future when mindone is ready for install
__dir__ = os.path.dirname(os.path.abspath(__file__))
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../"))
sys.path.append(mindone_lib_path)

from mg.dataset.tae_dataset import BatchTransform, VideoDataset
from mg.models.tae import TemporalAutoencoder

from mindone.data import create_dataloader
from mindone.utils import init_train_env
from mindone.utils.misc import to_abspath

logger = logging.getLogger()


def parse_train_args():
    parser = ArgumentParser(description="Temporal Autoencoder training script.")
    parser.add_argument(
        "-c",
        action=ActionConfigFile,
        help="Path to a config yaml file describing settings that will override the default arguments.",
    )
    parser.add_function_arguments(
        init_train_env, skip={"ascend_config", "num_workers", "json_data_path", "enable_modelarts"}
    )
    parser.add_class_arguments(TemporalAutoencoder, instantiate=False)
    parser.add_argument(
        "--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision."
    )
    parser.add_class_arguments(VideoDataset, skip={"output_columns"}, instantiate=False)
    parser.add_class_arguments(BatchTransform, instantiate=False)
    parser.add_function_arguments(
        create_dataloader,
        skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"},
    )
    parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results")
    parser.add_argument(
        "--add_datetime", default=True, type=bool, help="If True, add a datetime subfolder under output_path"
    )
    # model
    parser.add_argument("--perceptual_loss_weight", default=0.1, type=float, help="perceptual (lpips) loss weight")
    parser.add_argument("--kl_loss_weight", default=1.0e-6, type=float, help="KL loss weight")
    parser.add_argument(
        "--use_outlier_penalty_loss",
        default=False,
        type=bool,
        help="use outlier penalty loss",
    )
    # training hyper-params
    parser.add_argument(
        "--resume",
        default=False,
        type=str,
        help="Path to a checkpoint to resume from, or a bool False to not resume. (default=False)",
    )
    parser.add_argument("--optim", default="adamw_re", type=str, help="optimizer")
    parser.add_argument(
        "--betas",
        type=float,
        nargs="+",
        default=[0.9, 0.999],
        help="Specify the [beta1, beta2] parameters for the AdamW optimizer.",
    )
    parser.add_argument(
        "--optim_eps", type=float, default=1e-8, help="Specify the eps parameter for the AdamW optimizer."
    )
    parser.add_argument(
        "--group_strategy",
        type=str,
        default=None,
        help="Grouping strategy for weight decay. If `norm_and_bias`, the weight decay filter list is [beta, gamma, bias]. \
            If None, the filter list is [layernorm, bias]. Default: None",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.")
    parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps")
    parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.")
    parser.add_argument(
        "--scale_lr", default=False, type=bool, help="scale base-lr by ngpu * batch_size * n_accumulate"
    )
    parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.")
    parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.")
    parser.add_argument("--pre_patchify", default=False, type=bool, help="Training with patchified latent.")

    # dataloader params
    parser.add_argument("--dataset_sink_mode", default=False, type=bool, help="sink mode")
    parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. If -1, sink size = dataset size.")
    parser.add_argument(
        "--epochs",
        default=10,
        type=int,
        help="Number of epochs. If dataset_sink_mode is on, epochs are counted w.r.t. the dataset sink size; otherwise, w.r.t. the dataset size.",
    )
    parser.add_argument(
        "--train_steps", default=-1, type=int, help="If not -1, limit the number of training steps to the set value"
    )
    parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale")
    parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor")
    parser.add_argument("--scale_window", default=2000, type=float, help="scale window")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
    # parser.add_argument("--cond_stage_trainable", default=False, type=bool, help="whether text encoder is trainable")
    parser.add_argument("--use_ema", default=False, type=bool, help="whether to use EMA")
    parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay ratio")
    parser.add_argument("--clip_grad", default=False, type=bool, help="whether to apply gradient clipping")
    parser.add_argument(
        "--vae_keep_gn_fp32",
        default=True,
        type=bool,
        help="whether to keep GroupNorm in fp32.",
    )
    parser.add_argument(
        "--vae_keep_updown_fp32",
        default=True,
        type=bool,
        help="whether to keep spatial/temporal upsample and downsample in fp32.",
    )
    parser.add_argument(
        "--enable_flash_attention",
        default=None,
        type=bool,
        help="whether to enable flash attention.",
    )
    parser.add_argument("--drop_overflow_update", default=True, type=bool, help="drop overflow update")
    parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static")
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="max gradient norm for clipping, effective when `clip_grad` is enabled.",
    )
    parser.add_argument("--ckpt_save_interval", default=1, type=int, help="save a checkpoint every this many epochs")
    parser.add_argument(
        "--ckpt_save_steps",
        default=-1,
        type=int,
        help="save a checkpoint every this many steps. If -1, ckpt_save_interval is used instead.",
    )
    parser.add_argument("--ckpt_max_keep", default=10, type=int, help="Maximum number of checkpoints to keep")
    parser.add_argument(
        "--step_mode",
        default=False,
        type=bool,
        help="whether to save checkpoints by steps. If False, save checkpoints by epochs.",
    )
    parser.add_argument("--profile", default=False, type=bool, help="Profile or not")
    parser.add_argument(
        "--log_level",
        type=str,
        default="logging.INFO",
        help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR",
    )
    parser.add_argument(
        "--log_interval",
        default=1,
        type=int,
        help="Log interval in units of data sink size. E.g., if data sink size = 10 and log_interval = 2, log every 20 steps.",
    )
    return parser


def parse_args():
    parser = parse_train_args()
    args = parser.parse_args()

    __dir__ = os.path.dirname(os.path.abspath(__file__))
    abs_path = os.path.abspath(os.path.join(__dir__, ".."))

    # convert to absolute paths, necessary for ModelArts
    args.csv_path = to_abspath(abs_path, args.csv_path)
    args.video_folder = to_abspath(abs_path, args.video_folder)
    args.output_path = to_abspath(abs_path, args.output_path)
    args.pretrained_model_path = to_abspath(abs_path, args.pretrained_model_path)
    args.vae_checkpoint = to_abspath(abs_path, args.vae_checkpoint)
    print(args)

    return args
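
Since the parser is built with jsonargparse, values in a YAML file passed via `-c` override the argparse defaults, and explicit command-line flags override the YAML in turn. Below is a minimal sketch of that precedence; it assumes a `train_tae.py` entry point that calls `parse_args()` (the entry-point name is an assumption, not part of this diff):

# Hypothetical usage sketch (not part of this PR): config/CLI precedence with jsonargparse.
import sys

from args_train_tae import parse_args

sys.argv = [
    "train_tae.py",                                   # assumed entry point
    "-c", "configs/tae/train/mixed_256x256x16.yaml",  # YAML values override parser defaults
    "--epochs", "100",                                # explicit CLI flags override the YAML
]
args = parse_args()
print(args.epochs)               # 100 (from the CLI)
print(args.start_learning_rate)  # 1e-5 (from the YAML)
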
34 changes: 34 additions & 0 deletions examples/moviegen/configs/inference/moviegen_t2i_256px.yaml
@@ -0,0 +1,34 @@
env:
  mode: 0
  jit_level: O0
  seed: 42
  distributed: False
  debug: False

model:
  name: llama-5B
  pretrained_model_path:
  enable_flash_attention: True
  dtype: bf16

tae:
  pretrained: ""
  use_tile: True
  dtype: bf16

# Inference parameters
num_sampling_steps: 50
sample_method: linear-quadratic
image_size: [ 256, 256 ]
num_frames: 1 # image
text_emb:
  ul2_dir:
  metaclip_dir:
  byt5_dir:
batch_size: 10

# Saving options
output_path: ../../samples # the path is relative to this config
append_timestamp: True
save_format: png
save_latent: False
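
`sample_method: linear-quadratic` refers to the linear-quadratic timestep schedule described in the Movie Gen report: the early denoising steps are small and evenly spaced, and the remaining steps grow quadratically to cover the rest of the trajectory. The sketch below illustrates one common reading of that schedule; it is an illustration only, not the implementation shipped in this PR:

# Illustrative linear-quadratic step schedule (assumption: not this repo's exact code).
import numpy as np

def linear_quadratic_steps(num_steps: int = 50, num_train_steps: int = 1000, linear_frac: float = 0.5):
    n_linear = int(num_steps * linear_frac)
    linear = np.arange(n_linear)  # fine, evenly spaced early steps: 0, 1, 2, ...
    n_quad = num_steps - n_linear
    # quadratically growing steps that end exactly at num_train_steps
    quad = n_linear + ((np.arange(1, n_quad + 1) / n_quad) ** 2) * (num_train_steps - n_linear)
    return np.concatenate([linear, quad]).round().astype(int)

print(linear_quadratic_steps())  # 0..24 followed by quadratically spaced steps up to 1000
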
43 changes: 43 additions & 0 deletions examples/moviegen/configs/tae/train/mixed_256x256x16.yaml
@@ -0,0 +1,43 @@
# model
pretrained: "models/tae_vae2d.ckpt"

# loss
perceptual_loss_weight: 1.0
kl_loss_weight: 1.e-6
use_outlier_penalty_loss: False # OPL brings no benefit in our experiments
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
folder: "../videocomposer/datasets/webvid5"
sample_stride: 1
sample_n_frames: 16
image_size: 256
crop_size: 256
# flip: True

# training recipe
seed: 42
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.

dtype: "fp32"
use_recompute: False

epochs: 2000
ckpt_save_interval: 50
init_loss_scale: 1024.
loss_scaler_type: dynamic

scheduler: "constant"
use_ema: False

output_path: "outputs/tae_train"

# MindSpore setting
jit_level: O0
43 changes: 43 additions & 0 deletions examples/moviegen/configs/tae/train/mixed_256x256x32.yaml
@@ -0,0 +1,43 @@
# model
pretrained: "models/tae_vae2d.ckpt"

# loss
perceptual_loss_weight: 1.0
kl_loss_weight: 1.e-6
use_outlier_penalty_loss: False # OPL brings no benefit in our experiments
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
folder: "../videocomposer/datasets/webvid5"
sample_stride: 1
sample_n_frames: 32
image_size: 256
crop_size: 256
# flip: True

# training recipe
seed: 42
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.

dtype: "bf16"
use_recompute: True

epochs: 2000
ckpt_save_interval: 50
init_loss_scale: 1024.
loss_scaler_type: dynamic

scheduler: "constant"
use_ema: False

output_path: "outputs/tae_train"

# MindSpore setting
jit_level: O0
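
Both TAE recipes enable `mixed_strategy: "mixed_video_image"` with `mixed_image_ratio: 0.2`, i.e. roughly one in five training samples is treated as a single image rather than a clip. A rough sketch of one plausible reading of that switch (hypothetical helper, not code from this PR):

# Hypothetical illustration of mixed video-image sampling (not this repo's implementation).
import random

def maybe_as_image(frames, mixed_image_ratio: float = 0.2):
    """With probability `mixed_image_ratio`, keep only a single frame so the TAE
    also sees image samples; otherwise return the full clip unchanged."""
    if random.random() < mixed_image_ratio:
        return frames[:1]  # single-frame "image" sample
    return frames
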