Skip to content

Commit

Permalink
improve ASE traj (#633)
Browse files Browse the repository at this point in the history
add functions to convert from others formats to ASE traj format

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced system and labeled system handling with new parameters and
functionalities.
  - Improved stress calculations by modifying method parameters.

- **Bug Fixes**
- Corrected return types for several methods to ensure consistency and
reliability.

- **Tests**
- Added new test cases to validate system setups and trajectory file
operations.

- **Chores**
- Updated Python version matrix in workflow to include only version
3.11.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: C. Thang Nguyen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
3 people authored May 29, 2024
1 parent c5b36bb commit 199afc1
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
51 changes: 46 additions & 5 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import os
from typing import TYPE_CHECKING, Generator

import numpy as np

Expand Down Expand Up @@ -94,7 +95,7 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
"forces": np.array([forces]),
}
try:
stress = atoms.get_stress(False)
stress = atoms.get_stress(voigt=False)
except PropertyNotImplementedError:
pass
else:
Expand All @@ -110,7 +111,7 @@ def from_multi_systems(
step: int | None = None,
ase_fmt: str | None = None,
**kwargs,
) -> ase.Atoms:
) -> Generator[ase.Atoms, None, None]:
"""Convert a ASE supported file to ASE Atoms.
It will finally be converted to MultiSystems.
Expand Down Expand Up @@ -140,7 +141,7 @@ def from_multi_systems(
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
yield from frames

def to_system(self, data, **kwargs):
def to_system(self, data, **kwargs) -> list[ase.Atoms]:
"""Convert System to ASE Atom obj."""
from ase import Atoms

Expand All @@ -158,7 +159,7 @@ def to_system(self, data, **kwargs):

return structures

def to_labeled_system(self, data, *args, **kwargs):
def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
"""Convert System to ASE Atoms object."""
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
Expand Down Expand Up @@ -300,6 +301,46 @@ def from_labeled_system(

return dict_frames

def to_system(self, data, file_name: str = "confs.traj", **kwargs) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
from ase.io import Trajectory

if os.path.isfile(file_name):
os.remove(file_name)

list_atoms = ASEStructureFormat().to_system(data, **kwargs)
traj = Trajectory(file_name, "a")
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return

def to_labeled_system(
self, data, file_name: str = "labeled_confs.traj", *args, **kwargs
) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
from ase.io import Trajectory

if os.path.isfile(file_name):
os.remove(file_name)

list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs)
traj = Trajectory(file_name, "a")
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return


@Driver.register("ase")
class ASEDriver(Driver):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_ase_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,29 @@ def setUp(self):
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestASEtraj4(unittest.TestCase, CompSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd")
self.system_1.to(file_name="ase_traj/tmp.traj", fmt="ase/traj")
self.system_2 = dpdata.System("ase_traj/tmp.traj", fmt="ase/traj")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestASEtraj4Labeled(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd")
self.system_1.to(file_name="ase_traj/tmp1.traj", fmt="ase/traj")
self.system_2 = dpdata.LabeledSystem("ase_traj/tmp1.traj", fmt="ase/traj")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


if __name__ == "__main__":
unittest.main()

0 comments on commit 199afc1

Please sign in to comment.