From 7ca7e68c90d538e3c208bb3579a873f58da6538e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Sep 2023 17:18:33 -0400 Subject: [PATCH 1/5] drop the old data type with the same name and throw warning --- dpdata/system.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dpdata/system.py b/dpdata/system.py index 8af8e4a5..65e0a39f 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,6 +1,7 @@ # %% import glob import os +import warning from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Union @@ -963,7 +964,13 @@ def register_data_type(cls, *data_type: Tuple[DataType]): *data_type : tuple[DataType] data type to be regiestered """ - cls.DTYPES = cls.DTYPES + tuple(data_type) + all_dtypes = cls.DTYPES + tuple(data_type) + dtypes_dict = {} + for dt in all_dtypes: + if dt.name in dtypes_dict: + warnings.warn(f"Data type {dt.name} is registered twice; only the newly registered one will be used.", UserWarning) + dtypes_dict[dt.name] = dt + cls.DTYPES = tuple(dtypes_dict.values()) def get_cell_perturb_matrix(cell_pert_fraction): From 3d0980017fb7fdf52f625da96ad906717688ffe7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 21:22:45 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/system.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index 65e0a39f..990ea7f9 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,7 +1,6 @@ # %% import glob import os -import warning from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Union @@ -968,7 +967,10 @@ def register_data_type(cls, *data_type: Tuple[DataType]): dtypes_dict = {} for dt in all_dtypes: if dt.name in dtypes_dict: - warnings.warn(f"Data type {dt.name} is registered twice; only the newly registered one will be used.", UserWarning) + warnings.warn( + f"Data type {dt.name} is registered twice; only the newly registered one will be used.", + UserWarning, + ) dtypes_dict[dt.name] = dt cls.DTYPES = tuple(dtypes_dict.values()) From d332bfb5583b319533029379efdb037cec5bd392 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Sep 2023 17:41:22 -0400 Subject: [PATCH 3/5] import warnings --- dpdata/system.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpdata/system.py b/dpdata/system.py index 990ea7f9..0748db28 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,6 +1,7 @@ # %% import glob import os +import warnings from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Union From f171c491c48c09031d0d78fb1f04c010bfbe8084 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 15 Sep 2023 17:15:30 -0400 Subject: [PATCH 4/5] add test Signed-off-by: Jinzhe Zeng --- tests/test_custom_data_type.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 5a0e2bab..42d51abf 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -4,6 +4,7 @@ import numpy as np import dpdata +from dpdata.data_type import Axis, DataType class TestDeepmdLoadDumpComp(unittest.TestCase): @@ -44,6 +45,13 @@ def test_from_deepmd_hdf5(self): x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["foo"], self.foo) + def test_duplicated_data_type(self): + dt = DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False) + n_dtypes_old = len(dpdata.LabeledSystem.DTYPES) + with self.assertWarns(UserWarning): + dpdata.LabeledSystem.register_data_type(dt) + n_dtypes_new = len(dpdata.LabeledSystem.DTYPES) + self.assertEqual(n_dtypes_old, n_dtypes_new) class TestDeepmdLoadDumpCompAny(unittest.TestCase): def setUp(self): From d84b3de97837bb3f039668c3cf405cea6a7cf104 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 21:15:46 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_custom_data_type.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 42d51abf..006d6b01 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -53,6 +53,7 @@ def test_duplicated_data_type(self): n_dtypes_new = len(dpdata.LabeledSystem.DTYPES) self.assertEqual(n_dtypes_old, n_dtypes_new) + class TestDeepmdLoadDumpCompAny(unittest.TestCase): def setUp(self): self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")