Skip to content

Commit

Permalink
ENH: PR 4873 revisions
Browse files Browse the repository at this point in the history
* Expand `TPRReader()` support to include velocity handling,
and add tests/functionality for an additional tpx version (`133`).

[ci skip] [skip azp]
  • Loading branch information
tylerjereddy committed Dec 30, 2024
1 parent 4af047d commit 7583a2e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
6 changes: 5 additions & 1 deletion package/MDAnalysis/coordinates/TPR.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TPRReader(base.SingleFrameReaderBase):
# or perhaps combine the topology and coordinate reading
# with some inheritance shenanigans?
format = "TPR"
units = {"length": "nm"}
units = {"length": "nm", "velocity": "nm/ps"}
_Timestep = Timestep

def _read_first_frame(self):
Expand Down Expand Up @@ -71,3 +71,7 @@ def _read_first_frame(self):
if th.bX:
self.ts._pos = np.asarray(tpr_utils.ndo_rvec(data, th.natoms),
dtype=np.float32)
if th.bV:
self.ts._velocities = np.asarray(tpr_utils.ndo_rvec(data, th.natoms),
dtype=np.float32)
self.ts.has_velocities = True
2 changes: 1 addition & 1 deletion package/MDAnalysis/topology/tpr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def do_mtop(data, fver, tpr_resid_from_one=False):
# src/gromacs/fileio/tpxio.cpp
# TODO: expand tpx version support for striding to
# the coordinates
if fver == 134:
if fver >= 133:
# TODO: the following value is important, and not sure
# how to access programmatically yet...
# from GMX source code:
Expand Down
38 changes: 32 additions & 6 deletions testsuite/MDAnalysisTests/coordinates/test_tpr.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,60 @@
from MDAnalysisTests.datafiles import (TPR2024_4_bonded,
TPR_EXTRA_2024_4,
TPR2024_4)
TPR2024_4,
TPR2024)
import MDAnalysis as mda


import pytest
import numpy as np
from numpy.testing import assert_allclose, assert_equal


@pytest.mark.parametrize("tpr_file, exp_first_atom, exp_last_atom, exp_shape", [
(TPR2024_4_bonded,
@pytest.mark.parametrize("tpr_file, exp_first_atom, exp_last_atom, exp_shape, exp_vel_first_atom, exp_vel_last_atom", [
(TPR2024_4_bonded, # tpx 134
[4.446, 4.659, 2.384],
[4.446, 4.659, 2.384],
(14, 3),
np.zeros(3),
np.zeros(3),
),
# same coordinates, different shape
(TPR_EXTRA_2024_4,
(TPR_EXTRA_2024_4, # tpx 134
[4.446, 4.659, 2.384],
[4.446, 4.659, 2.384],
(18, 3),
np.zeros(3),
np.zeros(3),
),
# different coordinates and different shape
(TPR2024_4,
(TPR2024_4, # tpx 134
[3.25000e-01, 1.00400e+00, 1.03800e+00],
[-2.56000e-01, 1.37300e+00, 3.59800e+00],
(2263, 3),
np.zeros(3),
np.zeros(3),
),
(TPR2024, # tpx 133
[3.25000e-01, 1.00400e+00, 1.03800e+00],
[-2.56000e-01, 1.37300e+00, 3.59800e+00],
(2263, 3),
np.zeros(3),
np.zeros(3),
),
])
def test_basic_read_tpr(tpr_file, exp_first_atom, exp_last_atom, exp_shape):
def test_basic_read_tpr(tpr_file,
exp_first_atom,
exp_last_atom,
exp_shape,
exp_vel_first_atom,
exp_vel_last_atom):
# verify basic ability to read positions and
# velocities from GMX .tpr files
# expected values are from gmx dump
u = mda.Universe(tpr_file)
assert_allclose(u.atoms.positions[0, ...], exp_first_atom)
assert_allclose(u.atoms.positions[-1, ...], exp_last_atom)
assert_equal(u.atoms.positions.shape, exp_shape)
assert_allclose(u.atoms.velocities[0, ...], exp_vel_first_atom)
assert_allclose(u.atoms.velocities[-1, ...], exp_vel_last_atom)
assert_equal(u.atoms.velocities.shape, exp_shape)

0 comments on commit 7583a2e

Please sign in to comment.