From 0222d6dce97be21ea516f0fcd88bf21168b8f3b8 Mon Sep 17 00:00:00 2001 From: yaosk Date: Mon, 28 Oct 2024 18:46:48 +0800 Subject: [PATCH] add ut for move flags --- dpdata/abacus/scf.py | 4 ++-- dpdata/vasp/poscar.py | 15 +++++++++------ tests/test_vasp_poscar_to_system.py | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index d762c4dc..43e65f7e 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -615,8 +615,8 @@ def make_unlabeled_stru( numerical descriptor file mass : list of float, optional List of atomic masses - move : list of list of bool, optional - List of the move flag of each xyz direction of each atom + move : list of (list of list of bool), optional + List of the move flag of each xyz direction of each atom for each frame velocity : list of list of float, optional List of the velocity of each xyz direction of each atom mag : list of (list of float or float), optional diff --git a/dpdata/vasp/poscar.py b/dpdata/vasp/poscar.py index d323d692..075d8f2b 100644 --- a/dpdata/vasp/poscar.py +++ b/dpdata/vasp/poscar.py @@ -11,7 +11,7 @@ def move_flag_mapper(flag): elif flag == "F": return False else: - raise RuntimeError(f"Invalid selective dynamics flag: {flag}") + raise RuntimeError(f"Invalid move flag: {flag}") """Treat as cartesian poscar.""" system = {} @@ -35,8 +35,13 @@ def move_flag_mapper(flag): else: tmpv = np.matmul(np.array(tmpv), system["cells"][0]) coord.append(tmpv) - if selective_dynamics and len(tmp) == 6: - move_flags.append(list(map(move_flag_mapper, tmp[3:]))) + if selective_dynamics: + if len(tmp) == 6: + move_flags.append(list(map(move_flag_mapper, tmp[3:]))) + else: + raise RuntimeError( + f"Invalid move flags, should be 6 columns, got {tmp}" + ) system["coords"] = [np.array(coord)] system["orig"] = np.zeros(3) @@ -118,11 +123,9 @@ def from_system_data(system, f_idx=0, skip_zeros=True): move_flags = move[idx] if isinstance(move_flags, list) and len(move_flags) == 3: line += " " + " ".join(["T" if flag else "F" for flag in move_flags]) - elif isinstance(move_flags, (int, float, bool)): - line += " " + " ".join(["T" if move_flags else "F"] * 3) else: raise RuntimeError( - f"Invalid move flags: {move_flags}, should be a list or a bool" + f"Invalid move flags: {move_flags}, should be a list of 3 bools" ) posi_list.append(line) diff --git a/tests/test_vasp_poscar_to_system.py b/tests/test_vasp_poscar_to_system.py index 155f5588..8d642a2e 100644 --- a/tests/test_vasp_poscar_to_system.py +++ b/tests/test_vasp_poscar_to_system.py @@ -19,6 +19,21 @@ def test_move_flags(self): self.assertTrue(np.array_equal(self.system["move"], expected)) +class TestPOSCARCart(unittest.TestCase): + def test_move_flags_error1(self): + with self.assertRaisesRegex(RuntimeError, "Invalid move flags.*?"): + dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.err1")) + + def test_move_flags_error2(self): + with self.assertRaisesRegex(RuntimeError, "Invalid move flag: a"): + dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.err2")) + + def test_move_flags_error3(self): + system = dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.c")) + system.data["move"] = np.array([[[True, True], [False, False]]]) + with self.assertRaisesRegex(RuntimeError, "Invalid move flags:.*?should be a list of 3 bools"): + system.to_vasp_poscar("POSCAR.tmp.1") + class TestPOSCARDirect(unittest.TestCase, TestPOSCARoh): def setUp(self): self.system = dpdata.System()