download lpips auto
SamitHuang committed Aug 24, 2024
1 parent fac4272 commit 1f5aba1
Showing 5 changed files with 41 additions and 35 deletions.
22 changes: 6 additions & 16 deletions examples/opensora_hpcai/opensora/models/vae/lpips.py
@@ -6,9 +6,10 @@
 import mindspore as ms
 import mindspore.nn as nn
 import mindspore.ops as ops
+from mindone.utils.params import load_from_pretrained
 
-_logger = logging.getLogger(__name__)
 
+_logger = logging.getLogger(__name__)
 
 class LPIPS(nn.Cell):
     # Learned perceptual metric
@@ -22,7 +23,7 @@ def __init__(self, use_dropout=True):
         self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
         self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
         # load NetLin metric layers
-        self.load_from_pretrained()
+        self.load_lpips()
 
         self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
         self.lins = nn.CellList(self.lins)
@@ -34,21 +35,10 @@ def __init__(self, use_dropout=True):
         for param in self.trainable_params():
             param.requires_grad = False
 
-    def load_from_pretrained(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"):
-        # TODO: just load ms ckpt
+    def load_lpips(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"):
-        if not os.path.exists(ckpt_path):
-            raise ValueError(
-                f"{ckpt_path} not exists. Please download from https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt and move it to models/."
-            )
 
-        state_dict = ms.load_checkpoint(ckpt_path)
-        m, u = ms.load_param_into_net(self, state_dict)
-        if len(m) > 0:
-            print("missing keys:")
-            print(m)
-        if len(u) > 0:
-            print("unexpected keys:")
-            print(u)
+        ckpt_path = "https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt"
+        load_from_pretrained(self, ckpt_path)
 
         _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path))
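With this change, the LPIPS weights are fetched and cached automatically on first use instead of requiring a manual download into models/. A minimal usage sketch follows (illustrative only: the batch shapes and the two-image call signature are assumptions, not part of the diff):

import numpy as np
import mindspore as ms

from opensora.models.vae.lpips import LPIPS

# Instantiation triggers load_lpips(), which now delegates to
# mindone.utils.params.load_from_pretrained with the checkpoint url,
# downloading lpips_vgg-426bf45c.ckpt instead of raising a ValueError
# when models/lpips_vgg-426bf45c.ckpt is absent.
lpips = LPIPS()

# perceptual distance between two image batches in [-1, 1], NCHW layout
x = ms.Tensor(np.random.uniform(-1, 1, (1, 3, 256, 256)), ms.float32)
y = ms.Tensor(np.random.uniform(-1, 1, (1, 3, 256, 256)), ms.float32)
dist = lpips(x, y)
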
15 changes: 6 additions & 9 deletions examples/opensora_hpcai/opensora/models/vae/vae.py
@@ -130,7 +130,6 @@ def encode(self, x):
         if self.micro_batch_size is None:
             x_out = self.module.encode(x) * self.scale_factor
         else:
-
             bs = self.micro_batch_size
             x_out = self.module.encode(x[:bs]) * self.scale_factor
             for i in range(bs, x.shape[0], bs):
@@ -153,7 +152,7 @@ def decode(self, x, **kwargs):
         if self.micro_batch_size is None:
             x_out = self.module.decode(x / self.scale_factor)
         else:
-            mbs = self.micro_batch_size
+            mbs = self.micro_batch_size
 
             x_out = self.module.decode(x[:mbs] / self.scale_factor)
             for i in range(mbs, x.shape[0], mbs):
@@ -288,7 +287,6 @@ def encode(self, x):
 
         return (z_out - self.shift) / self.scale
 
-
     def decode(self, z, num_frames=None):
         if not self.cal_loss:
             z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
@@ -302,18 +300,18 @@ def decode(self, z, num_frames=None):
             return x
         else:
             # z: (b Z t//4 h w)
-            '''
+            """
             z_splits = mint.split(z, self.micro_z_frame_size, 2)
             x_z_out = tuple(self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames - i*self.micro_frame_size)) for i, z_bs in enumerate(z_splits))
             x_z_out = ops.cat(x_z_out, axis=2)
-            '''
+            """
             mz = self.micro_z_frame_size
-            remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size
-            x_z_out = self.temporal_vae.decode(z[:, :, : mz], num_frames=remain_frames)
+            remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size
+            x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=remain_frames)
             num_frames -= self.micro_frame_size
 
             for i in range(mz, z.shape[2], mz):
-                remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size
+                remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size
                 x_z_cur = self.temporal_vae.decode(z[:, :, i : i + mz], num_frames=remain_frames)
                 x_z_out = ops.cat((x_z_out, x_z_cur), axis=2)
                 num_frames -= self.micro_frame_size
@@ -325,7 +323,6 @@
             else:
                 return x
 
-
     def construct(self, x):
         # assert self.cal_loss, "This method is only available when cal_loss is True"
         z, posterior_mean, posterior_logvar, x_z = self.encode(x)
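The commented-out mint.split variant shown above was replaced by an explicit loop so that num_frames can be decremented chunk by chunk. A standalone sketch of that bookkeeping (plain Python; plan_decode_chunks is an illustrative name, not part of the repo):

def plan_decode_chunks(num_latent_frames, num_frames, micro_z_frame_size, micro_frame_size):
    # Each latent chunk of micro_z_frame_size frames decodes to at most
    # micro_frame_size output frames; the final chunk may keep fewer.
    chunks = []
    for start in range(0, num_latent_frames, micro_z_frame_size):
        remain_frames = num_frames if micro_frame_size > num_frames else micro_frame_size
        chunks.append((start, start + micro_z_frame_size, remain_frames))
        num_frames -= micro_frame_size
    return chunks

# plan_decode_chunks(8, 13, 2, 4)
# -> [(0, 2, 4), (2, 4, 4), (4, 6, 4), (6, 8, 1)]: the last chunk keeps only one frame
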
1 change: 1 addition & 0 deletions examples/opensora_hpcai/requirements.txt
@@ -18,3 +18,4 @@ tokenizers
 sentencepiece
 transformers
 pyav
+mindcv
10 changes: 4 additions & 6 deletions examples/opensora_hpcai/scripts/train_vae.py
@@ -20,9 +20,9 @@
 
 from args_train_vae import parse_args
 from opensora.datasets.vae_dataset import create_dataloader
+from opensora.models.layers.operation_selector import set_dynamic_mode
 from opensora.models.vae.losses import GeneratorWithLoss
 from opensora.models.vae.vae import OpenSoraVAE_V1_2
-from opensora.models.layers.operation_selector import set_dynamic_mode
 
 from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback
 from mindone.trainers.checkpoint import CheckpointManager, resume_train_network
@@ -134,7 +134,6 @@ def init_env(
         # only effective in GE mode, i.e. jit_level: O2
         ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"})
 
-
     if dynamic_shape:
         print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module")
         set_dynamic_mode(True)
@@ -158,7 +157,7 @@ def main(args):
         parallel_mode=args.parallel_mode,
         jit_level=args.jit_level,
         global_bf16=args.global_bf16,
-        dynamic_shape=(args.mixed_strategy=='mixed_video_random'),
+        dynamic_shape=(args.mixed_strategy == "mixed_video_random"),
         debug=args.debug,
     )
     set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level))
@@ -326,15 +325,14 @@ def main(args):
         ema=ema,
     )
 
-    # support dynamic shape in graph mode
-    if args.mode == 0 and args.mixed_strategy == 'mixed_video_random':
+    # support dynamic shape in graph mode
+    if args.mode == 0 and args.mixed_strategy == "mixed_video_random":
        # (b c t h w), drop_remainder so bs fixed
        # videos = ms.Tensor(shape=[args.batch_size, 3, None, image_size, image_size], dtype=ms.float32)
        videos = ms.Tensor(shape=[None, 3, None, image_size, image_size], dtype=ms.float32)
        training_step_ae.set_inputs(videos)
        logger.info("Dynamic inputs are initialized for mixed_video_random training in Graph mode!")
 
-
     if rank_id == 0:
         key_info = "Key Settings:\n" + "=" * 50 + "\n"
         key_info += "\n".join(
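For context on the set_inputs call above: in graph mode, MindSpore recompiles for every new input shape unless dynamic dimensions are declared through a placeholder Tensor whose varying axes are None. A self-contained sketch of the same pattern (the Conv3d is a stand-in for the actual training cell):

import numpy as np
import mindspore as ms
import mindspore.nn as nn

ms.set_context(mode=ms.GRAPH_MODE)

net = nn.Conv3d(3, 8, kernel_size=3)  # stand-in for training_step_ae
# batch and frame axes vary between steps; channels and spatial size stay fixed
videos = ms.Tensor(shape=[None, 3, None, 64, 64], dtype=ms.float32)
net.set_inputs(videos)

# both calls run through one compiled graph despite different batch/frame counts
out_a = net(ms.Tensor(np.ones((2, 3, 8, 64, 64)), ms.float32))
out_b = net(ms.Tensor(np.ones((1, 3, 12, 64, 64)), ms.float32))
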
28 changes: 24 additions & 4 deletions mindone/utils/params.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 from typing import List, Optional, Union
 
 import mindspore as ms
@@ -9,6 +10,13 @@
 
 # from mindspore._checkparam import Validator
 from mindspore.train.serialization import _load_dismatch_prefix_params, _update_param
+from mindcv.utils.download import Download
+
+
+def is_url(string):
+    # Regex to check for URL patterns
+    url_pattern = re.compile(r'^(http|https|ftp)://')
+    return bool(url_pattern.match(string))
 
 
 def load_param_into_net_with_filter(
@@ -98,17 +106,29 @@ def load_param_into_net_with_filter(
     return param_not_load, ckpt_not_load
 
 
-def load_checkpoint_to_net(
+def load_from_pretrained(
     net: nn.Cell,
     checkpoint: Union[str, dict],
     ignore_net_params_not_loaded=False,
     ensure_all_ckpt_params_loaded=False,
+    cache_dir: Optional[str] = None,
 ):
-    """
-    ignore_net_params_not_loaded: set True for inference if only a part of network needs to be loaded, the flushing net-not-loaded warnings will disappear.
-    ensure_all_ckpt_params_loaded : set True for inference if you want to ensure no checkpoint param is missed in loading
+    """Load checkpoint into network.
+
+    Args:
+        net: network
+        checkpoint: local file path to checkpoint, or url to download checkpoint from, or a dict of network parameters
+        ignore_net_params_not_loaded: set True for inference if only a part of the network needs to be loaded; the flushing net-not-loaded warnings will disappear
+        ensure_all_ckpt_params_loaded: set True for inference if you want to ensure no checkpoint param is missed in loading
+        cache_dir: directory to cache the downloaded checkpoint, only effective when `checkpoint` is a url
+    """
     if isinstance(checkpoint, str):
+        if is_url(checkpoint):
+            url = checkpoint
+            cache_dir = os.path.join(os.path.expanduser("~"), ".mindspore/models") if cache_dir is None else cache_dir
+            os.makedirs(cache_dir, exist_ok=True)
+            Download().download_url(url, path=cache_dir)
+            # the downloaded file lands in cache_dir, named after the url's basename
+            checkpoint = os.path.join(cache_dir, os.path.basename(url))
         if os.path.exists(checkpoint):
             param_dict = ms.load_checkpoint(checkpoint)
         else:
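The helper now resolves urls before loading. A usage sketch (illustrative: the Dense layer is a placeholder model whose parameters would not actually match the LPIPS checkpoint; real callers pass the matching network):

import mindspore.nn as nn

from mindone.utils.params import is_url, load_from_pretrained

ckpt = "https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt"
print(is_url(ckpt))  # True
print(is_url("models/lpips_vgg-426bf45c.ckpt"))  # False

# a url is downloaded once into cache_dir (default: ~/.mindspore/models),
# then loaded like a local file; local paths and param dicts behave as before
net = nn.Dense(4, 4)
load_from_pretrained(net, ckpt, ignore_net_params_not_loaded=True, cache_dir="/tmp/mindone_ckpts")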
