Skip to content

Commit

Permalink
fix: update run_caly_model_devi by xiaoyang
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzyphysics committed May 31, 2024
1 parent 4e364ec commit 5830700
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 77 deletions.
53 changes: 3 additions & 50 deletions dpgen2/op/run_caly_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def execute(
- `model_devi`: (`Artifact(List[Path])`) The model deviation. The order of recorded model deviations should be consistent with the order of frames in `traj`.
"""

from deepmd.infer import ( # type: ignore
DeepPot,
calc_model_devi,
Expand Down Expand Up @@ -110,15 +109,15 @@ def execute(
for atoms in atoms_list:
natoms = len(atoms)
dump_str = atoms2lmpdump(atoms, tcount, type_map, ignore=True)
dump_str_dict[natoms].append(dump_str)
dump_str_dict[tcount].append(dump_str)

pbc = np.all(atoms.get_pbc())
coord = atoms.get_positions().reshape(1, -1)
cell = atoms.get_cell().array.reshape(1, -1) if pbc else None
atype = [type_map.index(atom.symbol) for atom in atoms] # type: ignore
devi = calc_model_devi(coord, cell, atype, graphs)[0]
devis_dict[natoms].append(devi)
tcount += 1
devis_dict[tcount].append(devi)
tcount += 1

traj_file_list = []
model_devi_file_list = []
Expand Down Expand Up @@ -226,52 +225,6 @@ def atoms2lmpdump(atoms, struc_idx, type_map, ignore=False):
return dump_str


def parse_traj_deprecated(traj_file):
from ase import ( # type: ignore
Atoms,
)
from ase.build import ( # type: ignore
make_supercell,
)
from ase.io import ( # type: ignore
read,
)

# optimization will at least return one structures in traj file
trajs: List[Atoms] = read(traj_file, index=":", format="traj") # type: ignore

numb_traj = len(trajs)
assert numb_traj >= 1, "traj file is broken."

origin = trajs[0]
if len(origin) == 1:
origin = make_supercell(origin, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])
dis_mtx = origin.get_all_distances(mic=True)
row, col = np.diag_indices_from(dis_mtx)
dis_mtx[row, col] = np.nan
is_reasonable = np.nanmin(dis_mtx) > 0.6

selected_traj: Union[List[Atoms], None] = None
if is_reasonable:
if len(trajs) >= 20:
selected_traj = [trajs[iii] for iii in [4, 9, -10, -5, -1]]
elif 5 <= len(trajs) < 20:
selected_traj = [
trajs[np.random.randint(3, len(trajs) - 1)] for _ in range(4)
]
selected_traj.append(trajs[-1])
elif 3 <= len(trajs) < 5:
selected_traj = [trajs[round((len(trajs) - 1) / 2)]]
selected_traj.append(trajs[-1])
elif len(trajs) == 2:
selected_traj = [trajs[0], trajs[-1]]
else: # len(trajs) == 1
selected_traj = [trajs[0]]
else:
selected_traj = None
return selected_traj


def parse_traj(traj_file):
from ase import ( # type: ignore
Atoms,
Expand Down
204 changes: 177 additions & 27 deletions tests/op/test_run_caly_model_devi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os
import shutil
import unittest
from ast import (
Slice,
)
from pathlib import (
Path,
)
Expand All @@ -20,15 +17,11 @@
from ase.io import (
write,
)
from dflow import (
Step,
)
from dflow.python import (
OP,
OPIO,
Artifact,
OPIOSign,
PythonOPTemplate,
TransientError,
)

Expand Down Expand Up @@ -70,36 +63,161 @@ def setUp(self):
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_normal_1 = Atoms(
numbers=[1],
scaled_positions=[[0, 0, 0]],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_normal_3 = Atoms(
numbers=[1, 2, 3],
scaled_positions=[
[0, 0, 0],
[0.5, 0.5, 0.5],
[0.0, 0.0, 0.5],
[0.5, 0.0, 0.5],
],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_normal_3 = Atoms(

self.atoms_normal_4 = Atoms(
numbers=[1, 2, 3],
scaled_positions=[
[0, 0, 0],
[0.0, 0.0, 0.25],
[0.5, 0.0, 0.6],
],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_normal_5 = Atoms(
numbers=[1, 2, 3],
scaled_positions=[
[0, 0, 0],
[0.0, 0.0, 0.3],
[0.5, 0.0, 0.6],
],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_normal_6 = Atoms(
numbers=[1, 2, 3],
scaled_positions=[
[0, 0, 0],
[0.0, 0.0, 0.2],
[0.5, 0.0, 0.7],
],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
sa6 = self.atoms_normal_6
self.atoms_abnormal_2 = Atoms(
numbers=[1, 2, 3, 3],
scaled_positions=[
[0, 0, 0],
[0.0, 0.0, 0.5],
[0.5, 0.0, 0.5],
[0.5, 0.0, 0.6],
],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.atoms_abnormal = Atoms(
numbers=[1, 2],
scaled_positions=[[0, 0, 0], [0.0, 0.0, 0.0]],
scaled_positions=[[0.00001, 0, 0], [0.00002, 0.0, 0.0]],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)

self.atoms_abnormalpbc = Atoms(
numbers=[1, 2],
scaled_positions=[[0.00001, 0, 0], [0.99999, 0.0, 0.0]],
cell=[[10, 0, 0], [0, 10, 0], [0, 0, 10]],
)
self.traj_file_1 = self.work_dir.joinpath("1.traj")
self.traj_file_2 = self.work_dir.joinpath("2.traj")
self.traj_file_3 = self.work_dir.joinpath("3.traj")
self.traj_file_4 = self.work_dir.joinpath("4.traj")
self.traj_file_5 = self.work_dir.joinpath("5.traj")
self.traj_file_6 = self.work_dir.joinpath("6.traj")
self.traj_file_10 = self.work_dir.joinpath(
"10.traj"
) # in traj 10, test the abnormal configuration.
self.traj_file_20 = self.work_dir.joinpath(
"20.traj"
) # in traj 20, test the mix of normal and abnormal configuration.
self.traj_file_21 = self.work_dir.joinpath(
"21.traj"
) # in traj 21, test the mix of normal and abnormal configuration across pbc.

write(
self.traj_file_1,
[self.atoms_normal_1, self.atoms_normal_2, self.atoms_normal_3],
[self.atoms_normal_1],
format="traj",
)

write(
self.traj_file_2,
[self.atoms_normal_2],
format="traj",
)

write(
self.traj_file_3,
[self.atoms_normal_3, self.atoms_abnormal_2],
format="traj",
)

write(
self.traj_file_4,
[
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
sa6,
self.atoms_abnormal_2,
],
format="traj",
)

write(
self.traj_file_5,
[self.atoms_normal_3],
format="traj",
)

write(
self.traj_file_6,
[self.atoms_normal_3],
format="traj",
)
### The abnormal case ###
write(self.traj_file_10, self.atoms_abnormalpbc, format="traj")

### The mixed case ###
write(
self.traj_file_20, [self.atoms_normal_3, self.atoms_abnormal], format="traj"
)
write(
self.traj_file_21,
[self.atoms_normal_1, self.atoms_abnormalpbc],
format="traj",
)
write(self.traj_file_2, self.atoms_abnormal, format="traj")

self.ref_dump_str = """ITEM: TIMESTEP
1
Expand All @@ -113,9 +231,17 @@ def setUp(self):
1 1 0.0000000000 0.0000000000 0.0000000000 0.0000000000 0.0000000000 0.0000000000
2 2 5.0000000000 5.0000000000 5.0000000000 0.0000000000 0.0000000000 0.0000000000
"""
self.type_map = ["H", "He", "Li"]
self.type_map = ["H", "He", "Li", "Na"]
self.task_name = self.work_dir.joinpath(calypso_task_pattern % 0)
self.traj_dirs = [self.traj_file_1, self.traj_file_2]
self.traj_dirs = [
self.traj_file_1,
self.traj_file_2,
self.traj_file_3,
self.traj_file_4,
self.traj_file_5,
self.traj_file_6,
self.traj_file_10,
]

self.model_1 = self.work_dir.joinpath("model.000.pb")
self.model_2 = self.work_dir.joinpath("model.001.pb")
Expand All @@ -127,17 +253,30 @@ def tearDown(self):
shutil.rmtree(self.work_dir)

def test_00_parse_traj(self):
atoms_list_1 = parse_traj(self.traj_file_1)
self.assertEqual(len(atoms_list_1), 2)
self.assertAlmostEqual(atoms_list_1[-1], self.atoms_normal_3)
atoms_list_3 = parse_traj(self.traj_file_3)
self.assertEqual(len(atoms_list_3), 1) # reasonable results are added

atoms_list_2 = parse_traj(self.traj_file_2)
self.assertTrue(atoms_list_2 is None)
atoms_list_4 = parse_traj(self.traj_file_4)
self.assertEqual(len(atoms_list_4), 4) # reasonable results are added

atoms_list_20 = parse_traj(self.traj_file_20)
atoms_list_21 = parse_traj(self.traj_file_21)

self.assertEqual(len(atoms_list_20), 1) # unreasonable results are omitted
self.assertEqual(
len(atoms_list_21), 1
) # unreasonable results are omitted, if 2, ase ignors dangerous distances across pbc

self.assertAlmostEqual(atoms_list_3[-1], self.atoms_normal_3)
atoms_list_10 = parse_traj(self.traj_file_10)
self.assertTrue(atoms_list_10 is None), self.atoms_abnormal

def test_01_atoms2lmpdump(self):
dump_str = atoms2lmpdump(self.atoms_normal_2, 1, self.type_map)
self.assertEqual(dump_str, self.ref_dump_str)

# @patch("dpgen2.op.run_caly_model_devi.RunCalyModelDevi.import_deepmd_package.calc_model_devi")
# @patch("dpgen2.op.run_caly_model_devi.RunCalyModelDevi.import_deepmd_package.DP")
@unittest.skipIf(x == 1, "deepmd package not exists.")
@patch("deepmd.infer.calc_model_devi")
@patch("deepmd.infer.DeepPot")
Expand All @@ -164,24 +303,35 @@ def side_effect_2(*args, **kwargs):
)
)
# check output
self.assertEqual(len(out["traj"]), 2)
self.assertEqual(len(out["traj"]), 8)
self.assertTrue(
self.task_name / "traj.2.dump" in out["traj"],
self.task_name / "traj.0.dump" in out["traj"],
)
self.assertTrue(
self.task_name / "traj.1.dump" not in out["traj"],
self.task_name / "traj.1.dump" in out["traj"],
)
self.assertTrue(
self.task_name / "traj.3.dump" in out["traj"],
self.task_name / "traj.2.dump" in out["traj"],
)
(
self.assertTrue(
self.task_name / "traj.5.dump" in out["traj"],
),
)
self.assertTrue(
self.task_name / "traj.7.dump" in out["traj"],
)

self.assertEqual(len(out["model_devi"]), 2)
self.assertEqual(len(out["model_devi"]), 8)
self.assertTrue(
self.task_name / "model_devi.2.out" in out["model_devi"],
self.task_name / "model_devi.0.out" in out["model_devi"],
)
self.assertTrue(
self.task_name / "model_devi.1.out" in out["model_devi"],
)
self.assertTrue(
self.task_name / "model_devi.1.out" not in out["model_devi"],
self.task_name / "model_devi.2.out" in out["model_devi"],
)
self.assertTrue(
self.task_name / "model_devi.3.out" in out["model_devi"],
self.task_name / "model_devi.7.out" in out["model_devi"],
)

0 comments on commit 5830700

Please sign in to comment.