Skip to content

Commit

Permalink
add ASE's traj support (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
thangckt authored Mar 19, 2024
1 parent 46a8952 commit 4008687
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 2 deletions.
112 changes: 111 additions & 1 deletion dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
try:
import ase.io
from ase.calculators.calculator import PropertyNotImplementedError
from ase.io import Trajectory

if TYPE_CHECKING:
from ase.optimize.optimize import Optimizer
Expand Down Expand Up @@ -43,7 +44,7 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
data dict
"""
symbols = atoms.get_chemical_symbols()
atom_names = list(set(symbols))
atom_names = list(dict.fromkeys(symbols))
atom_numbs = [symbols.count(symbol) for symbol in atom_names]
atom_types = np.array([atom_names.index(symbol) for symbol in symbols]).astype(
int
Expand Down Expand Up @@ -187,6 +188,115 @@ def to_labeled_system(self, data, *args, **kwargs):
return structures


@Format.register("ase/traj")
class ASETrajFormat(Format):
"""Format for the ASE's trajectory format <https://wiki.fysik.dtu.dk/ase/ase/io/trajectory.html#module-ase.io.trajectory>`_ (ase).'
a `traj' contains a sequence of frames, each of which is an `Atoms' object.
"""

def from_system(
self,
file_name: str,
begin: Optional[int] = 0,
end: Optional[int] = None,
step: Optional[int] = 1,
**kwargs,
) -> dict:
"""Read ASE's trajectory file to `System` of multiple frames.
Parameters
----------
file_name : str
ASE's trajectory file
begin : int, optional
begin frame index
end : int, optional
end frame index
step : int, optional
frame index step
**kwargs : dict
other parameters
Returns
-------
dict_frames: dict
a dictionary containing data of multiple frames
"""
traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
for atoms in sub_traj[1:]:
tmp = ASEStructureFormat().from_system(atoms)
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])

## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)

return dict_frames

def from_labeled_system(
self,
file_name: str,
begin: Optional[int] = 0,
end: Optional[int] = None,
step: Optional[int] = 1,
**kwargs,
) -> dict:
"""Read ASE's trajectory file to `System` of multiple frames.
Parameters
----------
file_name : str
ASE's trajectory file
begin : int, optional
begin frame index
end : int, optional
end frame index
step : int, optional
frame index step
**kwargs : dict
other parameters
Returns
-------
dict_frames: dict
a dictionary containing data of multiple frames
"""
traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]

## check if the first frame has a calculator
if sub_traj[0].calc is None:
raise ValueError(
"The input trajectory does not contain energies and forces, may not be a labeled system."
)

dict_frames = ASEStructureFormat().from_labeled_system(sub_traj[0])
for atoms in sub_traj[1:]:
tmp = ASEStructureFormat().from_labeled_system(atoms)
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])
dict_frames["energies"] = np.append(
dict_frames["energies"], tmp["energies"][0]
)
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
dict_frames["virials"] = np.append(
dict_frames["virials"], tmp["virials"][0]
)

## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "virials" in dict_frames.keys():
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)

return dict_frames


@Driver.register("ase")
class ASEDriver(Driver):
"""ASE Driver.
Expand Down
Binary file added tests/ase_traj/MoS2.traj
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/ase_traj/MoS2/box.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
1.571723287959847148e+01 1.382976769974756167e-14 -4.335347714314398588e-27 -2.135776446678914908e+01 3.745820611488967700e+01 -5.045434369425391968e-16 2.142284677455918376e-24 8.545634906358754423e-15 2.312999999999935952e+01
1.571734059897834079e+01 1.382976314475292595e-14 -4.334590580314998760e-27 -2.135791084402753981e+01 3.745820611488959173e+01 -5.045434369425390982e-16 2.155934654129174673e-24 8.545634906358737068e-15 2.312999999999930978e+01
1.571733836538147244e+01 1.382976369106826543e-14 1.070818208280100644e-26 -2.135790780884721940e+01 3.745820611488941410e+01 -2.741620021286372525e-14 -1.227011684011214072e-23 -4.602381967266446384e-15 2.312999999999919609e+01
Loading

0 comments on commit 4008687

Please sign in to comment.