173 changes: 172 additions & 1 deletion dpdata/lammps/dump.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
from __future__ import annotations

import itertools
import os
import sys
from typing import TYPE_CHECKING
@@ -175,7 +176,177 @@ def box2dumpbox(orig, box):
    return bounds, tilt


def load_file(fname: FileType, begin=0, step=1):
def get_frame_nlines(fname: FileType):
    """
    Determine the number of lines per frame in a LAMMPS dump file.

    Parameters
    ----------
    fname : FileType
        The dump file name

    Returns
    -------
    int
        Number of lines per frame
    """
    with open_file(fname) as fp:
        frame_start = None
        line_count = 0

        while True:
            line = fp.readline()
            if not line:
                break
            line_count += 1

            if "ITEM: TIMESTEP" in line:
                if frame_start is None:
                    frame_start = line_count
                else:
                    # Found the start of the second frame
                    return line_count - frame_start

    # If we only have one frame, return the total line count
    return line_count
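A note on the assumption built into this helper: frames are treated as fixed-size blocks, so it (and read_frames below) presumes every frame has the same number of atoms and therefore the same line count. A minimal sketch, with an assumed file name:

# Illustrative only: "conf.dump" is an assumed file name, not part of this PR.
# A standard dump frame is typically 9 header lines plus one line per atom,
# so a 2-atom frame yields 9 + 2 = 11 lines per frame.
nlines = get_frame_nlines("conf.dump")
print(nlines)  # 11 for a 2-atom frame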


def read_frames(fname: FileType, f_idx: list[int]):
    """
    Efficiently read only specified frames from a LAMMPS dump file.

    Parameters
    ----------
    fname : FileType
        The dump file name
    f_idx : list[int]
        List of frame indices to read (0-based)

    Returns
    -------
    list[str]
        Lines for the requested frames, concatenated in ascending frame order
        (duplicate indices are read only once)
    """
    if not f_idx:
        return []

    # Sort frame indices for efficient sequential reading
    sorted_indices = sorted(set(f_idx))
    nlines = get_frame_nlines(fname)

    lines = []
    with open_file(fname) as fp:
        frame_idx = 0
        target_idx = 0

        # Read the file sequentially, one frame block (nlines lines) at a time
        while target_idx < len(sorted_indices):
            # Read a frame block
            frame_lines = []
            for _ in range(nlines):
                line = fp.readline()
                if not line:
                    return lines  # End of file
                frame_lines.append(line.rstrip("\n"))

            # Check if this is a frame we want
            if frame_idx == sorted_indices[target_idx]:
                lines.extend(frame_lines)
                target_idx += 1

            frame_idx += 1

            # Skip ahead if the next target frame is far away
            if target_idx < len(sorted_indices):
                frames_to_skip = sorted_indices[target_idx] - frame_idx
                if frames_to_skip > 0:
                    # Skip frames by reading and discarding lines
                    for _ in range(frames_to_skip * nlines):
                        line = fp.readline()
                        if not line:
                            return lines
                    frame_idx += frames_to_skip

    return lines
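For orientation, a hedged usage sketch of read_frames; the file name and indices are assumptions, and the returned lines can be sliced back into per-frame blocks because every block has the same length:

# Illustrative only: "traj.dump" and the frame indices are assumptions.
wanted = [0, 5, 12]
lines = read_frames("traj.dump", wanted)
nlines = get_frame_nlines("traj.dump")
# Re-split the flat line list into one block per recovered frame.
frames = [lines[i : i + nlines] for i in range(0, len(lines), nlines)]
assert len(frames) <= len(set(wanted))  # fewer if the file ends early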


def load_frames_from_trajectories(frames_dict, **kwargs):
    """
    Load frames from multiple trajectory files efficiently.

    This implements the pattern described in the issue:

        frames_dict = {
            Trajectory0: [23, 56, 78],
            Trajectory1: [22],
            ...
        }

    Parameters
    ----------
    frames_dict : dict
        Dictionary mapping trajectory file paths to lists of frame indices
    **kwargs
        Additional arguments passed to system_data (e.g., type_map, unwrap, input_file)

    Returns
    -------
    dict
        Combined system data from all requested frames
    """
    combined_data = None

    for traj_file, f_idx in frames_dict.items():
        if not f_idx:
            continue

        # Read specific frames from this trajectory
        lines = read_frames(traj_file, f_idx)
        if not lines:
            continue

        # Convert to system data
        data = system_data(lines, **kwargs)

        if combined_data is None:
            combined_data = data.copy()
        else:
            # Append data from this trajectory
            combined_data["cells"] = np.concatenate(
                [combined_data["cells"], data["cells"]], axis=0
            )
            combined_data["coords"] = np.concatenate(
                [combined_data["coords"], data["coords"]], axis=0
            )

            if "spins" in combined_data and "spins" in data:
                combined_data["spins"] = np.concatenate(
                    [combined_data["spins"], data["spins"]], axis=0
                )
            elif "spins" in data:
                combined_data["spins"] = data["spins"]

    return combined_data if combined_data is not None else {}
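A sketch of the frames_dict pattern from the docstring above, assuming both files exist and describe the same atom types and atom count (only cells, coords, and spins are concatenated; other fields such as atom_types come from the first trajectory):

# Illustrative only: the file names and frame indices are assumptions.
frames_dict = {
    "traj0.dump": [23, 56, 78],
    "traj1.dump": [22],
}
data = load_frames_from_trajectories(frames_dict, type_map=["O", "H"])
print(data["coords"].shape)  # (4, natoms, 3) if every requested frame exists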


def load_file(fname: FileType, begin=0, step=1, f_idx: list[int] | None = None):
    """
    Load frames from a LAMMPS dump file.

    Parameters
    ----------
    fname : FileType
        The dump file name
    begin : int, optional
        The begin frame index (used when f_idx is None)
    step : int, optional
        The step between frames (used when f_idx is None)
    f_idx : list[int], optional
        Specific frame indices to load. If provided, begin and step are ignored.

    Returns
    -------
    list[str]
        Lines for the requested frames
    """
    if f_idx is not None:
        # Use efficient frame reading for specific indices
        return read_frames(fname, f_idx)

    # Original implementation for begin/step reading
    lines = []
    buff = []
    cc = -1
5 changes: 4 additions & 1 deletion dpdata/plugins/lammps.py
@@ -69,6 +69,7 @@ def from_system(
        step: int = 1,
        unwrap: bool = False,
        input_file: str = None,
        f_idx: list[int] | None = None,
        **kwargs,
    ):
        """Read the data from a lammps dump file.
@@ -87,13 +88,15 @@
            Whether to unwrap the coordinates
        input_file : str, optional
            The input file name
        f_idx : list[int], optional
            Specific frame indices to load. If provided, begin and step are ignored.

        Returns
        -------
        dict
            The system data
        """
        lines = dpdata.lammps.dump.load_file(file_name, begin=begin, step=step)
        lines = dpdata.lammps.dump.load_file(
            file_name, begin=begin, step=step, f_idx=f_idx
        )
        data = dpdata.lammps.dump.system_data(
            lines, type_map, unwrap=unwrap, input_file=input_file
        )
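End to end, the new keyword is simply forwarded through dpdata.System; a minimal sketch, where the file name and frame indices are assumptions:

# Illustrative only: "dump.lammpstrj" and the indices are assumptions.
import dpdata

sys_sel = dpdata.System(
    "dump.lammpstrj", fmt="lammps/dump", type_map=["O", "H"], f_idx=[23, 56, 78]
)
print(sys_sel.get_nframes())  # 3; only the requested frames were materialized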
144 changes: 144 additions & 0 deletions tests/test_lammps_dump_efficient_read.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""Test efficient frame reading functionality for LAMMPS dump files."""

from __future__ import annotations

import os
import unittest

import numpy as np
from comp_sys import CompSys, IsPBC
from context import dpdata
import dpdata.lammps.dump as dump


class TestLAMMPSDumpEfficientRead(unittest.TestCase, CompSys, IsPBC):
    def setUp(self):
        self.dump_file = os.path.join("poscars", "conf.dump")
        self.type_map = ["O", "H"]
        self.places = 6
        self.e_places = 6
        self.f_places = 6
        self.v_places = 4

        # Set up comparison systems for inherited tests
        # Use the new efficient method as system_1
        self.system_1 = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[0]
        )
        # Use traditional method as system_2
        self.system_2 = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map
        ).sub_system([0])

    def test_get_frame_nlines(self):
        """Test frame line count detection."""
        nlines = dump.get_frame_nlines(self.dump_file)
        self.assertEqual(nlines, 11)  # Expected based on file structure

    def test_read_frames_single(self):
        """Test reading a single frame."""
        lines = dump.read_frames(self.dump_file, [1])
        self.assertEqual(len(lines), 11)
        self.assertTrue(lines[0].startswith("ITEM: TIMESTEP"))
        self.assertEqual(lines[1], "1")  # Second frame has timestep 1

    def test_read_frames_multiple(self):
        """Test reading multiple frames."""
        lines = dump.read_frames(self.dump_file, [0, 1])
        self.assertEqual(len(lines), 22)  # 11 lines per frame * 2 frames

    def test_read_frames_out_of_order(self):
        """Test reading frames in non-sequential order."""
        lines1 = dump.read_frames(self.dump_file, [1, 0])
        lines2 = dump.read_frames(self.dump_file, [0, 1])
        self.assertEqual(len(lines1), len(lines2))

    def test_read_frames_empty(self):
        """Test reading with empty frame list."""
        lines = dump.read_frames(self.dump_file, [])
        self.assertEqual(len(lines), 0)

    def test_load_file_with_f_idx(self):
        """Test enhanced load_file with f_idx parameter."""
        # Load specific frame
        lines = dump.load_file(self.dump_file, f_idx=[1])
        self.assertEqual(len(lines), 11)

        # Load multiple frames
        lines = dump.load_file(self.dump_file, f_idx=[0, 1])
        self.assertEqual(len(lines), 22)

        # Test that f_idx overrides begin/step
        lines = dump.load_file(self.dump_file, begin=1, step=1, f_idx=[0])
        self.assertEqual(len(lines), 11)

    def test_system_with_f_idx(self):
        """Test dpdata.System with f_idx parameter."""
        # Load all frames for comparison
        system_all = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map
        )

        # Load only second frame
        system_f1 = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[1]
        )

        self.assertEqual(len(system_all.data["coords"]), 2)
        self.assertEqual(len(system_f1.data["coords"]), 1)

        # Check that the frame data matches
        np.testing.assert_array_almost_equal(
            system_all.data["coords"][1], system_f1.data["coords"][0]
        )
        np.testing.assert_array_almost_equal(
            system_all.data["cells"][1], system_f1.data["cells"][0]
        )

    def test_load_frames_from_trajectories(self):
        """Test the frames_dict pattern."""
        frames_dict = {self.dump_file: [0, 1]}

        data = dump.load_frames_from_trajectories(frames_dict, type_map=self.type_map)

        self.assertIn("coords", data)
        self.assertIn("cells", data)
        self.assertEqual(len(data["coords"]), 2)
        self.assertEqual(len(data["cells"]), 2)

    def test_load_frames_from_trajectories_single(self):
        """Test the frames_dict pattern with single frame."""
        frames_dict = {self.dump_file: [1]}

        data = dump.load_frames_from_trajectories(frames_dict, type_map=self.type_map)

        self.assertIn("coords", data)
        self.assertIn("cells", data)
        self.assertEqual(len(data["coords"]), 1)
        self.assertEqual(len(data["cells"]), 1)

    def test_efficiency_comparison(self):
        """Compare efficiency by verifying we get the same results."""
        # Traditional approach: load all then filter
        system_traditional = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map
        )
        filtered_traditional = system_traditional.sub_system([1])

        # New efficient approach: load only frame 1
        system_efficient = dpdata.System(
            self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[1]
        )

        # Results should be identical
        np.testing.assert_array_almost_equal(
            filtered_traditional.data["coords"][0], system_efficient.data["coords"][0]
        )
        np.testing.assert_array_almost_equal(
            filtered_traditional.data["cells"][0], system_efficient.data["cells"][0]
        )

    def setUp_comp_sys(self):
        """Set up comparison systems for inherited tests."""
        pass  # Already set up in setUp


if __name__ == "__main__":
    unittest.main()