Skip to content

Commit

Permalink
1. Fixed units of taut, taup 2. Added missing ttime in NPT (#67)
Browse files Browse the repository at this point in the history
* 1. Fixed units of taut, taup 2. Added missing ttime in NPT

* Updated test_md.py to reflect the changes of taut and taup

* Reformatted to confront with pre-commit

---------

Co-authored-by: Ziyang HU <[email protected]>
  • Loading branch information
tsihyoung and Ziyang HU authored Sep 7, 2023
1 parent 5c07964 commit 4475871
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
8 changes: 5 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ def __init__(
)

if taut is None:
taut = 100 * timestep * units.fs
taut = 100 * timestep
if taup is None:
taup = 1000 * timestep * units.fs
taup = 1000 * timestep

if ensemble.lower() == "nve":
"""
Expand Down Expand Up @@ -483,6 +483,7 @@ def __init__(
atoms=self.atoms,
timestep=timestep * units.fs,
temperature_K=temperature,
ttime=taut * units.fs,
externalstress=None,
pfactor=None,
trajectory=trajectory,
Expand Down Expand Up @@ -550,13 +551,14 @@ def __init__(
see: https://gitlab.com/ase/ase/-/blob/master/ase/md/npt.py
ASE implementation currently only supports upper triangular lattice
"""
ptime = taup * units.fs
self.dyn = NPT(
atoms=self.atoms,
timestep=timestep * units.fs,
temperature_K=temperature,
externalstress=pressure * units.GPa,
ttime=taut * units.fs,
pfactor=bulk_modulus * units.GPa * taup * taup,
pfactor=bulk_modulus * units.GPa * ptime * ptime,
trajectory=trajectory,
logfile=logfile,
loginterval=loginterval,
Expand Down
53 changes: 40 additions & 13 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
from typing import TYPE_CHECKING, Literal

import numpy as np
import pytest
from ase import Atoms
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
Expand Down Expand Up @@ -60,20 +61,33 @@ def test_md_nvt(
logfile="md_out.log",
loginterval=10,
)
md.run(10)
md.run(100)

assert isinstance(md.atoms, Atoms)
assert isinstance(md.atoms.calc, CHGNetCalculator)
assert isinstance(md.dyn, NVTBerendsen)
assert os.path.isfile("md_out.traj")
assert os.path.isfile("md_out.log")
with open("md_out.log") as log_file:
next(log_file)
logs = log_file.read()
assert logs == (
"Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n"
logs = np.fromstring(logs, dtype=float, sep=" ")
ref = np.fromstring(
"0.0000 -58.9727 -58.9727 0.0000 0.0\n"
"0.0200 -58.9723 -58.9731 0.0009 0.8\n"
"0.0400 -58.9672 -58.9727 0.0055 5.4\n"
"0.0600 -58.9427 -58.9663 0.0235 22.8\n"
"0.0800 -58.8605 -58.9352 0.0747 72.2\n"
"0.1000 -58.7651 -58.8438 0.0786 76.0\n"
"0.1200 -58.6684 -58.7268 0.0584 56.4\n"
"0.1400 -58.5703 -58.6202 0.0499 48.2\n"
"0.1600 -58.4724 -58.5531 0.0807 78.1\n"
"0.1800 -58.3891 -58.8077 0.4186 404.8\n"
"0.2000 -58.3398 -58.9244 0.5846 565.4\n",
dtype=float,
sep=" ",
)
assert np.isclose(logs, ref, rtol=2.1e-3, atol=1e-8).all()


def test_md_nve(tmp_path: Path, monkeypatch: MonkeyPatch):
Expand Down Expand Up @@ -118,7 +132,7 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch
logfile="md_out.log",
loginterval=10,
)
md.run(10)
md.run(100)

assert isinstance(md.atoms, Atoms)
assert isinstance(md.atoms.calc, CHGNetCalculator)
Expand All @@ -128,12 +142,25 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch
assert os.path.isfile("md_out.traj")
assert os.path.isfile("md_out.log")
with open("md_out.log") as log_file:
next(log_file)
logs = log_file.read()
assert logs == (
"Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n"
logs = np.fromstring(logs, dtype=float, sep=" ")
ref = np.fromstring(
"0.0000 -58.9727 -58.9727 0.0000 0.0\n"
"0.0200 -58.9723 -58.9732 0.0009 0.8\n"
"0.0200 -58.9723 -58.9731 0.0009 0.8\n"
"0.0400 -58.9672 -58.9727 0.0055 5.3\n"
"0.0600 -58.9427 -58.9663 0.0235 22.7\n"
"0.0800 -58.8605 -58.9352 0.0747 72.2\n"
"0.1000 -58.7652 -58.8438 0.0786 76.0\n"
"0.1200 -58.6686 -58.7269 0.0584 56.4\n"
"0.1400 -58.5707 -58.6205 0.0499 48.2\n"
"0.1600 -58.4731 -58.5533 0.0802 77.6\n"
"0.1800 -58.3897 -58.8064 0.4167 402.9\n"
"0.2000 -58.3404 -58.9253 0.5849 565.6\n",
dtype=float,
sep=" ",
)
assert np.isclose(logs, ref, rtol=2.1e-3, atol=1e-8).all()


def test_md_crystal_feas_log(
Expand All @@ -152,7 +179,7 @@ def test_md_crystal_feas_log(
crystal_feas_logfile="md_crystal_feas.p",
loginterval=1,
)
md.run(10)
md.run(100)

assert os.path.isfile("md_crystal_feas.p")
with open("md_crystal_feas.p", "rb") as file:
Expand All @@ -161,9 +188,9 @@ def test_md_crystal_feas_log(
crystal_feas = data["crystal_feas"]

assert isinstance(crystal_feas, list)
assert len(crystal_feas) == 11
assert len(crystal_feas) == 101
assert len(crystal_feas[0]) == 64
assert crystal_feas[0][0] == approx(1.4411175, rel=1e-5)
assert crystal_feas[0][1] == approx(2.6527007, rel=1e-5)
assert crystal_feas[10][0] == approx(1.4390144, rel=1e-5)
assert crystal_feas[10][1] == approx(2.65252, rel=1e-5)
assert crystal_feas[0][0] == approx(1.4411131, rel=1e-5)
assert crystal_feas[0][1] == approx(2.652704, rel=1e-5)
assert crystal_feas[10][0] == approx(1.4390125, rel=1e-5)
assert crystal_feas[10][1] == approx(2.6525214, rel=1e-5)

0 comments on commit 4475871

Please sign in to comment.