Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 5, 2024
1 parent ff2ede6 commit b8c454c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 28 deletions.
30 changes: 16 additions & 14 deletions dpdata/plugins/deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,36 @@
import dpdata.deepmd.hdf5
import dpdata.deepmd.mixed
import dpdata.deepmd.raw
from dpdata.data_type import Axis, DataType
from dpdata.driver import Driver
from dpdata.format import Format

from dpdata.data_type import Axis, DataType

if TYPE_CHECKING:
import h5py


def register_spin():
dt = DataType(
"spins",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="spin",
)
"spins",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="spin",
)
dpdata.System.register_data_type(dt)
dpdata.LabeledSystem.register_data_type(dt)

dt = DataType(
"force_mags",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="force_mag",
)
"force_mags",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="force_mag",
)
dpdata.System.register_data_type(dt)
dpdata.LabeledSystem.register_data_type(dt)


@Format.register("deepmd")
@Format.register("deepmd/raw")
class DeePMDRawFormat(Format):
Expand Down Expand Up @@ -230,6 +231,7 @@ def _from_system(
file_name is not str or h5py.Group or h5py.File
"""
import h5py

register_spin()

if isinstance(file_name, (h5py.Group, h5py.File)):
Expand Down
33 changes: 19 additions & 14 deletions tests/test_deepmd_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,45 @@
import shutil
import unittest

import numpy as np
from context import dpdata


class TestDeepmdReadSpinNPY(unittest.TestCase):
def setUp(self):
self.tmp_save_path = "tmp.deepmd.spin/dump-tmp"

def tearDown(self):
if os.path.exists(self.tmp_save_path):
shutil.rmtree(self.tmp_save_path)

def check_Fe16(self, system):
self.assertTrue("spins" in system.data)
self.assertTrue("force_mags" in system.data)
self.assertEqual(system.data["spins"].shape, (2, 16, 3))
self.assertEqual(system.data["force_mags"].shape, (2, 16, 3))

def test_read_spin_npy(self):
system = dpdata.LabeledSystem("tmp.deepmd.spin/Fe16-npy", fmt="deepmd/npy")
self.check_Fe16(system)

system.to( "deepmd/npy",self.tmp_save_path)
self.assertTrue(os.path.isfile(os.path.join(self.tmp_save_path, "set.000/spin.npy")))
self.assertTrue(os.path.isfile(os.path.join(self.tmp_save_path, "set.000/force_mag.npy")))


system.to("deepmd/npy", self.tmp_save_path)
self.assertTrue(
os.path.isfile(os.path.join(self.tmp_save_path, "set.000/spin.npy"))
)
self.assertTrue(
os.path.isfile(os.path.join(self.tmp_save_path, "set.000/force_mag.npy"))
)

def test_read_spin_raw(self):
system = dpdata.LabeledSystem("tmp.deepmd.spin/Fe16-raw", fmt="deepmd/raw")
self.check_Fe16(system)
system.to( "deepmd/raw",self.tmp_save_path)

system.to("deepmd/raw", self.tmp_save_path)
self.assertTrue(os.path.isfile(os.path.join(self.tmp_save_path, "spin.raw")))
self.assertTrue(os.path.isfile(os.path.join(self.tmp_save_path, "force_mag.raw")))


self.assertTrue(
os.path.isfile(os.path.join(self.tmp_save_path, "force_mag.raw"))
)


if __name__ == "__main__":
unittest.main()

0 comments on commit b8c454c

Please sign in to comment.