Skip to content

Commit

Permalink
fix virial in HybridDriver (#604)
Browse files Browse the repository at this point in the history
Based on #603

---------

Co-authored-by: robinzyb <[email protected]>
  • Loading branch information
njzjz and robinzyb authored Feb 2, 2024
1 parent e43f00e commit 9a03f77
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions dpdata/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def label(self, data: dict) -> dict:
else:
labeled_data["energies"] += lb_data["energies"]
labeled_data["forces"] += lb_data["forces"]
if "virials" in labeled_data and "virials" in lb_data:
labeled_data["virials"] += lb_data["virials"]
return labeled_data


Expand Down
4 changes: 2 additions & 2 deletions tests/comp_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def test_virial(self):
# if len(self.system_1['virials']) == 0:
# self.assertEqual(len(self.system_1['virials']), 0)
# return
if "virials" not in self.system_1:
self.assertFalse("virials" in self.system_2)
if not self.system_1.has_virial():
self.assertFalse(self.system_2.has_virial())
return
np.testing.assert_almost_equal(
self.system_1["virials"],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def setUp(self):
self.system_2 = dpdata.LabeledSystem(
"poscars/deepmd.h2o.md", fmt="deepmd/raw", type_map=["O", "H"]
)
for pp in ("energies", "forces"):
for pp in ("energies", "forces", "virials"):
self.system_2.data[pp][:] = 3.0

self.places = 6
Expand Down

0 comments on commit 9a03f77

Please sign in to comment.