diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index c6a595ab..6b80ebdf 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,3 +1,4 @@ +import os from __future__ import annotations from typing import TYPE_CHECKING, Generator @@ -313,6 +314,9 @@ def to_system(self, data, file_name: str = "confs.traj", **kwargs) -> None: file_name : str path to file """ + if os.path.isfile(file_name): + os.remove(file_name) + list_atoms = ASEStructureFormat().to_system(data, **kwargs) traj = Trajectory(file_name, "a") _ = [traj.write(atom) for atom in list_atoms] @@ -329,6 +333,9 @@ def to_labeled_system( file_name : str path to file """ + if os.path.isfile(file_name): + os.remove(file_name) + list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs) traj = Trajectory(file_name, "a") _ = [traj.write(atom) for atom in list_atoms] diff --git a/tests/test_ase_traj.py b/tests/test_ase_traj.py index efe3126d..98bc8bc0 100644 --- a/tests/test_ase_traj.py +++ b/tests/test_ase_traj.py @@ -70,10 +70,9 @@ def setUp(self): @unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") class TestASEtraj4(unittest.TestCase, CompSys, IsPBC): def setUp(self): - system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd") - system_1.to(file_name="tmp.traj", fmt="ase/traj") - self.system_2 = dpdata.System("tmp.traj", fmt="ase/traj") - self.system_1 = system_1 + self.system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd") + self.system_1.to(file_name="ase_traj/tmp.traj", fmt="ase/traj") + self.system_2 = dpdata.System("ase_traj/tmp.traj", fmt="ase/traj") self.places = 6 self.e_places = 6 self.f_places = 6 @@ -83,10 +82,9 @@ def setUp(self): @unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") class TestASEtraj4Labeled(unittest.TestCase, CompLabeledSys, IsPBC): def setUp(self): - system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd") - system_1.to(file_name="tmp1.traj", fmt="ase/traj") - self.system_2 = dpdata.LabeledSystem("tmp1.traj", fmt="ase/traj") - self.system_1 = system_1 + self.system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd") + self.system_1.to(file_name="ase_traj/tmp1.traj", fmt="ase/traj") + self.system_2 = dpdata.LabeledSystem("ase_traj/tmp1.traj", fmt="ase/traj") self.places = 6 self.e_places = 6 self.f_places = 6