Skip to content

Commit

Permalink
Merge pull request #13 from thangckt/PR
Browse files Browse the repository at this point in the history
fetch PR
  • Loading branch information
thangckt authored May 27, 2024
2 parents b5a9171 + 62fe922 commit 27cec4d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
36 changes: 22 additions & 14 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generator, List
import os
from typing import TYPE_CHECKING, Generator

import numpy as np

Expand All @@ -11,7 +12,11 @@
if TYPE_CHECKING:
import ase
from ase.optimize.optimize import Optimizer

try:
from ase.io import Trajectory
except ImportError:
pass


@Format.register("ase/structure")
Expand Down Expand Up @@ -111,7 +116,7 @@ def from_multi_systems(
step: int | None = None,
ase_fmt: str | None = None,
**kwargs,
) -> Generator["ase.Atoms", None, None]:
) -> Generator[ase.Atoms, None, None]:
"""Convert a ASE supported file to ASE Atoms.
It will finally be converted to MultiSystems.
Expand Down Expand Up @@ -141,7 +146,7 @@ def from_multi_systems(
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
yield from frames

def to_system(self, data, **kwargs) -> List["ase.Atoms"]:
def to_system(self, data, **kwargs) -> list[ase.Atoms]:
"""Convert System to ASE Atom obj."""
from ase import Atoms

Expand All @@ -159,7 +164,7 @@ def to_system(self, data, **kwargs) -> List["ase.Atoms"]:

return structures

def to_labeled_system(self, data, *args, **kwargs) -> List["ase.Atoms"]:
def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
"""Convert System to ASE Atoms object."""
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
Expand Down Expand Up @@ -301,35 +306,38 @@ def from_labeled_system(

return dict_frames

def to_system(self,
data,
file_name: str = "confs.traj",
**kwargs) -> None:
def to_system(self, data, file_name: str = "confs.traj", **kwargs) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
if os.path.isfile(file_name):
os.remove(file_name)

list_atoms = ASEStructureFormat().to_system(data, **kwargs)
traj = Trajectory(file_name, 'a')
traj = Trajectory(file_name, "a")
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return

def to_labeled_system(self,
data,
file_name: str = "labeled_confs.traj",
*args, **kwargs) -> None:
def to_labeled_system(
self, data, file_name: str = "labeled_confs.traj", *args, **kwargs
) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
if os.path.isfile(file_name):
os.remove(file_name)

list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs)
traj = Trajectory(file_name, 'a')
traj = Trajectory(file_name, "a")
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return
Expand Down
24 changes: 24 additions & 0 deletions tests/test_ase_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,29 @@ def setUp(self):
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestASEtraj4(unittest.TestCase, CompSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd")
self.system_1.to(file_name="ase_traj/tmp.traj", fmt="ase/traj")
self.system_2 = dpdata.System("ase_traj/tmp.traj", fmt="ase/traj")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestASEtraj4Labeled(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd")
self.system_1.to(file_name="ase_traj/tmp1.traj", fmt="ase/traj")
self.system_2 = dpdata.LabeledSystem("ase_traj/tmp1.traj", fmt="ase/traj")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


if __name__ == "__main__":
unittest.main()

0 comments on commit 27cec4d

Please sign in to comment.