Skip to content

Commit

Permalink
fix dl.load cannot hanle yaml file
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Dec 19, 2024
1 parent c706dd1 commit aca825d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions danling/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,16 @@ def save(obj: Any, file: PathStr, *args: List[Any], **kwargs: Dict[str, Any]) ->
obj.json(file)
else:
with open(file, "w") as fp:
json.dump(obj, fp, *args, **kwargs) # type: ignore
json.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type]
elif extension in YAML:
if isinstance(obj, FlatDict):
obj.yaml(file)
else:
with open(file, "w") as fp:
yaml.dump(obj, fp, *args, **kwargs) # type: ignore
yaml.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type, call-overload]
elif extension in PICKLE:
with open(file, "wb") as fp:
pickle.dump(obj, fp, *args, **kwargs) # type: ignore
pickle.dump(obj, fp, *args, **kwargs) # type: ignore[arg-type]
else:
raise ValueError(f"Tying to save {obj} to {file!r} with unsupported extension={extension!r}")
return file
Expand All @@ -135,13 +135,14 @@ def load(file: PathStr, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return numpy.load(file, *args, **kwargs)
if extension in JSON:
with open(file) as fp:
return json.load(fp, *args, **kwargs) # type: ignore
return json.load(fp, *args, **kwargs) # type: ignore[arg-type]
if extension in YAML:
with open(file) as fp:
return yaml.load(fp, *args, **kwargs) # type: ignore
kwargs.setdefault("Loader", yaml.FullLoader) # type: ignore[arg-type]
return yaml.load(fp, *args, **kwargs) # type: ignore[arg-type]
if extension in PICKLE:
with open(file, "rb") as fp:
return pickle.load(fp, *args, **kwargs) # type: ignore
return pickle.load(fp, *args, **kwargs) # type: ignore[arg-type]
if extension in PANDAS_SUPPORTED:
return load_pandas(file, *args, **kwargs)
raise ValueError(f"Tying to load {file!r} with unsupported extension={extension!r}")
Expand Down

0 comments on commit aca825d

Please sign in to comment.