diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 1d818483..3ee35c28 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -62,6 +62,11 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a LabeledSystem. Energies and forces are calculated by the calculator. + Note that this method will try to load virials from the following sources: + - atoms.info['virial'] + - atoms.info['virials'] + - converted from stress tensor + Parameters ---------- atoms : ase.Atoms @@ -93,13 +98,21 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: "energies": np.array([energies]), "forces": np.array([forces]), } - try: - stress = atoms.get_stress(False) - except PropertyNotImplementedError: - pass - else: - virials = np.array([-atoms.get_volume() * stress]) + + # try to get virials from different sources + virials = atoms.info.get("virial") + if virials is None: + virials = atoms.info.get("virials") + if virials is None: + try: + stress = atoms.get_stress(False) + except PropertyNotImplementedError: + pass + else: + virials = np.array([-atoms.get_volume() * stress]) + if virials is not None: info_dict["virials"] = virials + return info_dict def from_multi_systems( @@ -165,7 +178,6 @@ def to_labeled_system(self, data, *args, **kwargs): structures = [] species = [data["atom_names"][tt] for tt in data["atom_types"]] - for ii in range(data["coords"].shape[0]): structure = Atoms( symbols=species,