diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 48d8ca68..c07e3ca9 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional, Type, Generator, List import numpy as np @@ -114,7 +114,7 @@ def from_multi_systems( step: Optional[int] = None, ase_fmt: Optional[str] = None, **kwargs, - ) -> object: # generator of "ase.Atoms" + ) -> Generator["ase.Atoms"]: """Convert a ASE supported file to ASE Atoms. It will finally be converted to MultiSystems. @@ -142,7 +142,7 @@ def from_multi_systems( frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step)) yield from frames - def to_system(self, data, **kwargs): + def to_system(self, data, **kwargs) -> List["ase.Atoms"]: """Convert System to ASE Atom obj.""" from ase import Atoms @@ -160,7 +160,7 @@ def to_system(self, data, **kwargs): return structures - def to_labeled_system(self, data, *args, **kwargs): + def to_labeled_system(self, data, *args, **kwargs) -> List["ase.Atoms"]: """Convert System to ASE Atoms object.""" from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator @@ -298,20 +298,35 @@ def from_labeled_system( return dict_frames - def to_system(self, data, **kwargs): - """Convert System to ASE Atoms object.""" + def to_system(self, + data, + file_name: str = "confs.traj", + **kwargs) -> None: + """Convert System to ASE Atoms object. + + Parameters + ---------- + file_name : str + path to file + """ list_atoms = ASEStructureFormat().to_system(data, **kwargs) - file_name = kwargs.get("file_name", "conf.traj") - traj = Trajectory(file_name, "a") + traj = Trajectory(file_name, 'a') _ = [traj.write(atom) for atom in list_atoms] traj.close() return - def to_labeled_system(self, data, *args, **kwargs): - """Convert System to ASE Atoms object.""" + def to_labeled_system(self, + data, + file_name: str = "labeled_confs.traj", + *args, **kwargs) -> None: + """Convert System to ASE Atoms object. + Parameters + ---------- + file_name : str + path to file + """ list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs) - file_name = kwargs.get("file_name", "labeled_conf.traj") - traj = Trajectory(file_name, "a") + traj = Trajectory(file_name, 'a') _ = [traj.write(atom) for atom in list_atoms] traj.close() return