diff --git a/examples/opensora_hpcai/opensora/models/vae/lpips.py b/examples/opensora_hpcai/opensora/models/vae/lpips.py index 53c91c4ba3..e8d29363fd 100644 --- a/examples/opensora_hpcai/opensora/models/vae/lpips.py +++ b/examples/opensora_hpcai/opensora/models/vae/lpips.py @@ -6,11 +6,12 @@ import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops -from mindone.utils.params import load_from_pretrained +from mindone.utils.params import load_from_pretrained _logger = logging.getLogger(__name__) + class LPIPS(nn.Cell): # Learned perceptual metric def __init__(self, use_dropout=True): diff --git a/mindone/utils/params.py b/mindone/utils/params.py index cff045be7a..ead1b65b92 100644 --- a/mindone/utils/params.py +++ b/mindone/utils/params.py @@ -3,6 +3,8 @@ import re from typing import List, Optional, Union +from mindcv.utils.download import Download + import mindspore as ms import mindspore.nn as nn from mindspore import Parameter @@ -10,12 +12,11 @@ # from mindspore._checkparam import Validator from mindspore.train.serialization import _load_dismatch_prefix_params, _update_param -from mindcv.utils.download import Download def is_url(string): # Regex to check for URL patterns - url_pattern = re.compile(r'^(http|https|ftp)://') + url_pattern = re.compile(r"^(http|https|ftp)://") return bool(url_pattern.match(string)) @@ -111,9 +112,9 @@ def load_from_pretrained( checkpoint: Union[str, dict], ignore_net_params_not_loaded=False, ensure_all_ckpt_params_loaded=False, - cache_dir: str=None, + cache_dir: str = None, ): - """ load checkpoint into network. + """load checkpoint into network. Args: net: network