diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 13e69561..7755e651 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -448,9 +448,9 @@ def __init__( ) if taut is None: - taut = 100 * timestep * units.fs + taut = 100 * timestep if taup is None: - taup = 1000 * timestep * units.fs + taup = 1000 * timestep if ensemble.lower() == "nve": """ @@ -483,6 +483,7 @@ def __init__( atoms=self.atoms, timestep=timestep * units.fs, temperature_K=temperature, + ttime=taut * units.fs, externalstress=None, pfactor=None, trajectory=trajectory, @@ -550,13 +551,14 @@ def __init__( see: https://gitlab.com/ase/ase/-/blob/master/ase/md/npt.py ASE implementation currently only supports upper triangular lattice """ + ptime = taup * units.fs self.dyn = NPT( atoms=self.atoms, timestep=timestep * units.fs, temperature_K=temperature, externalstress=pressure * units.GPa, ttime=taut * units.fs, - pfactor=bulk_modulus * units.GPa * taup * taup, + pfactor=bulk_modulus * units.GPa * ptime * ptime, trajectory=trajectory, logfile=logfile, loginterval=loginterval, diff --git a/tests/test_md.py b/tests/test_md.py index 79aacc78..b417a7f2 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -4,6 +4,7 @@ import pickle from typing import TYPE_CHECKING, Literal +import numpy as np import pytest from ase import Atoms from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen @@ -60,7 +61,7 @@ def test_md_nvt( logfile="md_out.log", loginterval=10, ) - md.run(10) + md.run(100) assert isinstance(md.atoms, Atoms) assert isinstance(md.atoms.calc, CHGNetCalculator) @@ -68,12 +69,25 @@ def test_md_nvt( assert os.path.isfile("md_out.traj") assert os.path.isfile("md_out.log") with open("md_out.log") as log_file: + next(log_file) logs = log_file.read() - assert logs == ( - "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n" + logs = np.fromstring(logs, dtype=float, sep=" ") + ref = np.fromstring( "0.0000 -58.9727 -58.9727 0.0000 0.0\n" "0.0200 -58.9723 -58.9731 0.0009 0.8\n" + "0.0400 -58.9672 -58.9727 0.0055 5.4\n" + "0.0600 -58.9427 -58.9663 0.0235 22.8\n" + "0.0800 -58.8605 -58.9352 0.0747 72.2\n" + "0.1000 -58.7651 -58.8438 0.0786 76.0\n" + "0.1200 -58.6684 -58.7268 0.0584 56.4\n" + "0.1400 -58.5703 -58.6202 0.0499 48.2\n" + "0.1600 -58.4724 -58.5531 0.0807 78.1\n" + "0.1800 -58.3891 -58.8077 0.4186 404.8\n" + "0.2000 -58.3398 -58.9244 0.5846 565.4\n", + dtype=float, + sep=" ", ) + assert np.isclose(logs, ref, rtol=2.1e-3, atol=1e-8).all() def test_md_nve(tmp_path: Path, monkeypatch: MonkeyPatch): @@ -118,7 +132,7 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch logfile="md_out.log", loginterval=10, ) - md.run(10) + md.run(100) assert isinstance(md.atoms, Atoms) assert isinstance(md.atoms.calc, CHGNetCalculator) @@ -128,12 +142,25 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch assert os.path.isfile("md_out.traj") assert os.path.isfile("md_out.log") with open("md_out.log") as log_file: + next(log_file) logs = log_file.read() - assert logs == ( - "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n" + logs = np.fromstring(logs, dtype=float, sep=" ") + ref = np.fromstring( "0.0000 -58.9727 -58.9727 0.0000 0.0\n" - "0.0200 -58.9723 -58.9732 0.0009 0.8\n" + "0.0200 -58.9723 -58.9731 0.0009 0.8\n" + "0.0400 -58.9672 -58.9727 0.0055 5.3\n" + "0.0600 -58.9427 -58.9663 0.0235 22.7\n" + "0.0800 -58.8605 -58.9352 0.0747 72.2\n" + "0.1000 -58.7652 -58.8438 0.0786 76.0\n" + "0.1200 -58.6686 -58.7269 0.0584 56.4\n" + "0.1400 -58.5707 -58.6205 0.0499 48.2\n" + "0.1600 -58.4731 -58.5533 0.0802 77.6\n" + "0.1800 -58.3897 -58.8064 0.4167 402.9\n" + "0.2000 -58.3404 -58.9253 0.5849 565.6\n", + dtype=float, + sep=" ", ) + assert np.isclose(logs, ref, rtol=2.1e-3, atol=1e-8).all() def test_md_crystal_feas_log( @@ -152,7 +179,7 @@ def test_md_crystal_feas_log( crystal_feas_logfile="md_crystal_feas.p", loginterval=1, ) - md.run(10) + md.run(100) assert os.path.isfile("md_crystal_feas.p") with open("md_crystal_feas.p", "rb") as file: @@ -161,9 +188,9 @@ def test_md_crystal_feas_log( crystal_feas = data["crystal_feas"] assert isinstance(crystal_feas, list) - assert len(crystal_feas) == 11 + assert len(crystal_feas) == 101 assert len(crystal_feas[0]) == 64 - assert crystal_feas[0][0] == approx(1.4411175, rel=1e-5) - assert crystal_feas[0][1] == approx(2.6527007, rel=1e-5) - assert crystal_feas[10][0] == approx(1.4390144, rel=1e-5) - assert crystal_feas[10][1] == approx(2.65252, rel=1e-5) + assert crystal_feas[0][0] == approx(1.4411131, rel=1e-5) + assert crystal_feas[0][1] == approx(2.652704, rel=1e-5) + assert crystal_feas[10][0] == approx(1.4390125, rel=1e-5) + assert crystal_feas[10][1] == approx(2.6525214, rel=1e-5)