From aca825d9f65da43ac5cff826f92e168c787509fe Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 19 Dec 2024 15:16:51 +0800 Subject: [PATCH] fix dl.load cannot hanle yaml file Signed-off-by: Zhiyuan Chen --- danling/utils/io.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/danling/utils/io.py b/danling/utils/io.py index 8ce6a662..24f2b65d 100644 --- a/danling/utils/io.py +++ b/danling/utils/io.py @@ -103,16 +103,16 @@ def save(obj: Any, file: PathStr, *args: List[Any], **kwargs: Dict[str, Any]) -> obj.json(file) else: with open(file, "w") as fp: - json.dump(obj, fp, *args, **kwargs) # type: ignore + json.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type] elif extension in YAML: if isinstance(obj, FlatDict): obj.yaml(file) else: with open(file, "w") as fp: - yaml.dump(obj, fp, *args, **kwargs) # type: ignore + yaml.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type, call-overload] elif extension in PICKLE: with open(file, "wb") as fp: - pickle.dump(obj, fp, *args, **kwargs) # type: ignore + pickle.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type] else: raise ValueError(f"Tying to save {obj} to {file!r} with unsupported extension={extension!r}") return file @@ -135,13 +135,14 @@ def load(file: PathStr, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: return numpy.load(file, *args, **kwargs) if extension in JSON: with open(file) as fp: - return json.load(fp, *args, **kwargs) # type: ignore + return json.load(fp, *args, **kwargs) # type: ignore[arg-type] if extension in YAML: with open(file) as fp: - return yaml.load(fp, *args, **kwargs) # type: ignore + kwargs.setdefault("Loader", yaml.FullLoader) # type: ignore[arg-type] + return yaml.load(fp, *args, **kwargs) # type: ignore[arg-type] if extension in PICKLE: with open(file, "rb") as fp: - return pickle.load(fp, *args, **kwargs) # type: ignore + return pickle.load(fp, *args, **kwargs) # type: ignore[arg-type] if extension in PANDAS_SUPPORTED: return load_pandas(file, *args, **kwargs) raise ValueError(f"Tying to load {file!r} with unsupported extension={extension!r}")