From 1f5aba1ed3292e0fe67d023920bf0100c5ce47a2 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Sat, 24 Aug 2024 16:09:04 +0800 Subject: [PATCH] download lpips auto --- .../opensora/models/vae/lpips.py | 22 ++++----------- .../opensora_hpcai/opensora/models/vae/vae.py | 15 ++++------ examples/opensora_hpcai/requirements.txt | 1 + examples/opensora_hpcai/scripts/train_vae.py | 10 +++---- mindone/utils/params.py | 28 ++++++++++++++++--- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/examples/opensora_hpcai/opensora/models/vae/lpips.py b/examples/opensora_hpcai/opensora/models/vae/lpips.py index aa54031d85..53c91c4ba3 100644 --- a/examples/opensora_hpcai/opensora/models/vae/lpips.py +++ b/examples/opensora_hpcai/opensora/models/vae/lpips.py @@ -6,9 +6,10 @@ import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops +from mindone.utils.params import load_from_pretrained -_logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) class LPIPS(nn.Cell): # Learned perceptual metric @@ -22,7 +23,7 @@ def __init__(self, use_dropout=True): self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) # load NetLin metric layers - self.load_from_pretrained() + self.load_lpips() self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] self.lins = nn.CellList(self.lins) @@ -34,21 +35,10 @@ def __init__(self, use_dropout=True): for param in self.trainable_params(): param.requires_grad = False - def load_from_pretrained(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"): - # TODO: just load ms ckpt + def load_lpips(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"): if not os.path.exists(ckpt_path): - raise ValueError( - f"{ckpt_path} not exists. Please download from https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt and move it to models/." 
- ) - - state_dict = ms.load_checkpoint(ckpt_path) - m, u = ms.load_param_into_net(self, state_dict) - if len(m) > 0: - print("missing keys:") - print(m) - if len(u) > 0: - print("unexpected keys:") - print(u) + ckpt_path = "https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt" + load_from_pretrained(self, ckpt_path) _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path)) diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index f93be0e496..ae69828942 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -130,7 +130,6 @@ def encode(self, x): if self.micro_batch_size is None: x_out = self.module.encode(x) * self.scale_factor else: - bs = self.micro_batch_size x_out = self.module.encode(x[:bs]) * self.scale_factor for i in range(bs, x.shape[0], bs): @@ -153,7 +152,7 @@ def decode(self, x, **kwargs): if self.micro_batch_size is None: x_out = self.module.decode(x / self.scale_factor) else: - mbs = self.micro_batch_size + mbs = self.micro_batch_size x_out = self.module.decode(x[:mbs] / self.scale_factor) for i in range(mbs, x.shape[0], mbs): @@ -288,7 +287,6 @@ def encode(self, x): return (z_out - self.shift) / self.scale - def decode(self, z, num_frames=None): if not self.cal_loss: z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) @@ -302,18 +300,18 @@ def decode(self, z, num_frames=None): return x else: # z: (b Z t//4 h w) - ''' + """ z_splits = mint.split(z, self.micro_z_frame_size, 2) x_z_out = tuple(self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames - i*self.micro_frame_size)) for i, z_bs in enumerate(z_splits)) x_z_out = ops.cat(x_z_out, axis=2) - ''' + """ mz = self.micro_z_frame_size - remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size - x_z_out = self.temporal_vae.decode(z[:, :, : mz], 
num_frames=remain_frames) + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=remain_frames) num_frames -= self.micro_frame_size for i in range(mz, z.shape[2], mz): - remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size x_z_cur = self.temporal_vae.decode(z[:, :, i : i + mz], num_frames=remain_frames) x_z_out = ops.cat((x_z_out, x_z_cur), axis=2) num_frames -= self.micro_frame_size @@ -325,7 +323,6 @@ def decode(self, z, num_frames=None): else: return x - def construct(self, x): # assert self.cal_loss, "This method is only available when cal_loss is True" z, posterior_mean, posterior_logvar, x_z = self.encode(x) diff --git a/examples/opensora_hpcai/requirements.txt b/examples/opensora_hpcai/requirements.txt index 2980b29c41..94107856a9 100644 --- a/examples/opensora_hpcai/requirements.txt +++ b/examples/opensora_hpcai/requirements.txt @@ -18,3 +18,4 @@ tokenizers sentencepiece transformers pyav +mindcv diff --git a/examples/opensora_hpcai/scripts/train_vae.py b/examples/opensora_hpcai/scripts/train_vae.py index 4e5c15e15a..31f75af6de 100644 --- a/examples/opensora_hpcai/scripts/train_vae.py +++ b/examples/opensora_hpcai/scripts/train_vae.py @@ -20,9 +20,9 @@ from args_train_vae import parse_args from opensora.datasets.vae_dataset import create_dataloader +from opensora.models.layers.operation_selector import set_dynamic_mode from opensora.models.vae.losses import GeneratorWithLoss from opensora.models.vae.vae import OpenSoraVAE_V1_2 -from opensora.models.layers.operation_selector import set_dynamic_mode from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network @@ -134,7 +134,6 @@ def init_env( # only effective in GE mode, 
i.e. jit_level: O2 ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) - if dynamic_shape: print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") set_dynamic_mode(True) @@ -158,7 +157,7 @@ def main(args): parallel_mode=args.parallel_mode, jit_level=args.jit_level, global_bf16=args.global_bf16, - dynamic_shape=(args.mixed_strategy=='mixed_video_random'), + dynamic_shape=(args.mixed_strategy == "mixed_video_random"), debug=args.debug, ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) @@ -326,15 +325,14 @@ def main(args): ema=ema, ) - # support dynamic shape in graph mode - if args.mode == 0 and args.mixed_strategy == 'mixed_video_random': + # support dynamic shape in graph mode + if args.mode == 0 and args.mixed_strategy == "mixed_video_random": # (b c t h w), drop_remainder so bs fixed # videos = ms.Tensor(shape=[args.batch_size, 3, None, image_size, image_size], dtype=ms.float32) videos = ms.Tensor(shape=[None, 3, None, image_size, image_size], dtype=ms.float32) training_step_ae.set_inputs(videos) logger.info("Dynamic inputs are initialized for mixed_video_random training in Graph mode!") - if rank_id == 0: key_info = "Key Settings:\n" + "=" * 50 + "\n" key_info += "\n".join( diff --git a/mindone/utils/params.py b/mindone/utils/params.py index 0fa69b377f..cff045be7a 100644 --- a/mindone/utils/params.py +++ b/mindone/utils/params.py @@ -1,5 +1,6 @@ import copy import os +import re from typing import List, Optional, Union import mindspore as ms @@ -9,6 +10,13 @@ # from mindspore._checkparam import Validator from mindspore.train.serialization import _load_dismatch_prefix_params, _update_param +from mindcv.utils.download import Download + + +def is_url(string): + # Regex to check for URL patterns + url_pattern = re.compile(r'^(http|https|ftp)://') + return bool(url_pattern.match(string)) def load_param_into_net_with_filter( @@ -98,17 +106,29 @@ def 
load_param_into_net_with_filter( return param_not_load, ckpt_not_load -def load_checkpoint_to_net( +def load_from_pretrained( net: nn.Cell, checkpoint: Union[str, dict], ignore_net_params_not_loaded=False, ensure_all_ckpt_params_loaded=False, + cache_dir: str=None, ): - """ - ignore_net_params_not_loaded: set True for inference if only a part of network needs to be loaded, the flushing net-not-loaded warnings will disappear. - ensure_all_ckpt_params_loaded : set True for inference if you want to ensure no checkpoint param is missed in loading + """ load checkpoint into network. + + Args: + net: network + checkpoint: local file path to checkpoint, or url to download checkpoint, or a dict for network parameters + ignore_net_params_not_loaded: set True for inference if only a part of network needs to be loaded, the flushing net-not-loaded warnings will disappear. + ensure_all_ckpt_params_loaded : set True for inference if you want to ensure no checkpoint param is missed in loading + cache_dir: directory to cache the downloaded checkpoint, only effective when `checkpoint` is a url. """ if isinstance(checkpoint, str): + if is_url(checkpoint): + url = checkpoint + cache_dir = os.path.join(os.path.expanduser("~"), ".mindspore/models") if cache_dir is None else cache_dir + os.makedirs(cache_dir, exist_ok=True) + Download().download_url(url, path=cache_dir) + checkpoint = os.path.join(cache_dir, os.path.basename(url)) if os.path.exists(checkpoint): param_dict = ms.load_checkpoint(checkpoint) else: