Skip to content

Commit

Permalink
add test to FrameAnalysis
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxuanzhuang committed Jan 14, 2025
1 parent 0bd3d54 commit 2c9982b
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions testsuite/MDAnalysisTests/analysis/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,34 @@ def __init__(self, reader, **kwargs):

def _prepare(self):
self.results.found_frames = []
self.results.frame_index = []
self.results.global_frame_index = []
self.results.n_frames = []
self.results.global_n_frames = []

# self.n_frames is defined elsewhere
self.global_n_frames = len(self._trajectory[self._global_slicer])

def _single_frame(self):
frame_index = self._frame_index
global_frame_index = self._global_frame_index

self.results.found_frames.append(self._ts.frame)
self.results.frame_index.append(frame_index)
self.results.global_frame_index.append(global_frame_index)
self.results.n_frames.append(self.n_frames)
self.results.global_n_frames.append(self.global_n_frames)

def _conclude(self):
self.found_frames = list(self.results.found_frames)

def _get_aggregator(self):
return base.ResultsGroup(
{"found_frames": base.ResultsGroup.ndarray_hstack}
{"found_frames": base.ResultsGroup.ndarray_hstack,
"frame_index": base.ResultsGroup.ndarray_hstack,
"global_frame_index": base.ResultsGroup.ndarray_hstack,
"n_frames": base.ResultsGroup.ndarray_hstack,
"global_n_frames": base.ResultsGroup.ndarray_hstack}
)


Expand Down Expand Up @@ -450,12 +468,17 @@ def test_frames_times(client_FrameAnalysis):
start=1, stop=8, step=2, **client_FrameAnalysis
)
frames = np.array([1, 3, 5, 7])
assert an.n_frames == len(frames)
n_frames = len(frames)
frame_indices = np.arange(n_frames)

assert an.n_frames == n_frames
assert_equal(an.found_frames, frames)
assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
assert_allclose(
an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR
)
assert_equal(an.results.global_frame_index, frame_indices)
assert_equal(an.results.global_n_frames, [n_frames] * n_frames)


def test_verbose(u):
Expand Down

0 comments on commit 2c9982b

Please sign in to comment.