diff --git a/dpdata/lammps/dump.py b/dpdata/lammps/dump.py index fe549b95..3f998d1c 100644 --- a/dpdata/lammps/dump.py +++ b/dpdata/lammps/dump.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from __future__ import annotations +import itertools import os import sys from typing import TYPE_CHECKING @@ -175,7 +176,177 @@ def box2dumpbox(orig, box): return bounds, tilt -def load_file(fname: FileType, begin=0, step=1): +def get_frame_nlines(fname: FileType): + """ + Determine the number of lines per frame in a LAMMPS dump file. + + Parameters + ---------- + fname : FileType + The dump file name + + Returns + ------- + int + Number of lines per frame + """ + with open_file(fname) as fp: + frame_start = None + line_count = 0 + + while True: + line = fp.readline() + if not line: + break + line_count += 1 + + if "ITEM: TIMESTEP" in line: + if frame_start is None: + frame_start = line_count + else: + # Found the start of the second frame + return line_count - frame_start + + # If we only have one frame, return the total line count + return line_count + + +def read_frames(fname: FileType, f_idx: list[int]): + """ + Efficiently read only specified frames from a LAMMPS dump file. + + Parameters + ---------- + fname : FileType + The dump file name + f_idx : list[int] + List of frame indices to read (0-based) + + Returns + ------- + list[str] + Lines for the requested frames + """ + if not f_idx: + return [] + + # Sort frame indices for efficient sequential reading + sorted_indices = sorted(set(f_idx)) + nlines = get_frame_nlines(fname) + + lines = [] + with open_file(fname) as fp: + frame_idx = 0 + target_idx = 0 + + # Use itertools.zip_longest to read frames in blocks + while target_idx < len(sorted_indices): + # Read a frame block + frame_lines = [] + for _ in range(nlines): + line = fp.readline() + if not line: + return lines # End of file + frame_lines.append(line.rstrip("\n")) + + # Check if this is a frame we want + if frame_idx == sorted_indices[target_idx]: + lines.extend(frame_lines) + target_idx += 1 + + frame_idx += 1 + + # Skip ahead if the next target frame is far away + if target_idx < len(sorted_indices): + frames_to_skip = sorted_indices[target_idx] - frame_idx + if frames_to_skip > 0: + # Skip frames by reading and discarding lines + for _ in range(frames_to_skip * nlines): + line = fp.readline() + if not line: + return lines + frame_idx += frames_to_skip + + return lines + + +def load_frames_from_trajectories(frames_dict, **kwargs): + """ + Load frames from multiple trajectory files efficiently. + + This implements the pattern described in the issue: + frames_dict = { + Trajectory0: [23, 56, 78], + Trajectory1: [22], + ... + } + + Parameters + ---------- + frames_dict : dict + Dictionary mapping trajectory file paths to lists of frame indices + **kwargs + Additional arguments passed to system_data (e.g., type_map, unwrap, input_file) + + Returns + ------- + dict + Combined system data from all requested frames + """ + combined_data = None + + for traj_file, f_idx in frames_dict.items(): + if not f_idx: + continue + + # Read specific frames from this trajectory + lines = read_frames(traj_file, f_idx) + if not lines: + continue + + # Convert to system data + data = system_data(lines, **kwargs) + + if combined_data is None: + combined_data = data.copy() + else: + # Append data from this trajectory + combined_data["cells"] = np.concatenate([combined_data["cells"], data["cells"]], axis=0) + combined_data["coords"] = np.concatenate([combined_data["coords"], data["coords"]], axis=0) + + if "spins" in combined_data and "spins" in data: + combined_data["spins"] = np.concatenate([combined_data["spins"], data["spins"]], axis=0) + elif "spins" in data: + combined_data["spins"] = data["spins"] + + return combined_data if combined_data is not None else {} + + +def load_file(fname: FileType, begin=0, step=1, f_idx: list[int] = None): + """ + Load frames from a LAMMPS dump file. + + Parameters + ---------- + fname : FileType + The dump file name + begin : int, optional + The begin frame index (used when f_idx is None) + step : int, optional + The step between frames (used when f_idx is None) + f_idx : list[int], optional + Specific frame indices to load. If provided, begin and step are ignored. + + Returns + ------- + list[str] + Lines for the requested frames + """ + if f_idx is not None: + # Use efficient frame reading for specific indices + return read_frames(fname, f_idx) + + # Original implementation for begin/step reading lines = [] buff = [] cc = -1 diff --git a/dpdata/plugins/lammps.py b/dpdata/plugins/lammps.py index c7e5c765..617f4fd8 100644 --- a/dpdata/plugins/lammps.py +++ b/dpdata/plugins/lammps.py @@ -69,6 +69,7 @@ def from_system( step: int = 1, unwrap: bool = False, input_file: str = None, + f_idx: list[int] = None, **kwargs, ): """Read the data from a lammps dump file. @@ -87,13 +88,15 @@ def from_system( Whether to unwrap the coordinates input_file : str, optional The input file name + f_idx : list[int], optional + Specific frame indices to load. If provided, begin and step are ignored. Returns ------- dict The system data """ - lines = dpdata.lammps.dump.load_file(file_name, begin=begin, step=step) + lines = dpdata.lammps.dump.load_file(file_name, begin=begin, step=step, f_idx=f_idx) data = dpdata.lammps.dump.system_data( lines, type_map, unwrap=unwrap, input_file=input_file ) diff --git a/tests/test_lammps_dump_efficient_read.py b/tests/test_lammps_dump_efficient_read.py new file mode 100644 index 00000000..899ec943 --- /dev/null +++ b/tests/test_lammps_dump_efficient_read.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Test efficient frame reading functionality for LAMMPS dump files.""" + +from __future__ import annotations + +import os +import unittest + +import numpy as np +from comp_sys import CompSys, IsPBC +from context import dpdata +import dpdata.lammps.dump as dump + + +class TestLAMMPSDumpEfficientRead(unittest.TestCase, CompSys, IsPBC): + def setUp(self): + self.dump_file = os.path.join("poscars", "conf.dump") + self.type_map = ["O", "H"] + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4 + + # Set up comparison systems for inherited tests + # Use the new efficient method as system_1 + self.system_1 = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[0]) + # Use traditional method as system_2 + self.system_2 = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map).sub_system([0]) + + def test_get_frame_nlines(self): + """Test frame line count detection.""" + nlines = dump.get_frame_nlines(self.dump_file) + self.assertEqual(nlines, 11) # Expected based on file structure + + def test_read_frames_single(self): + """Test reading a single frame.""" + lines = dump.read_frames(self.dump_file, [1]) + self.assertEqual(len(lines), 11) + self.assertTrue(lines[0].startswith("ITEM: TIMESTEP")) + self.assertEqual(lines[1], "1") # Second frame has timestep 1 + + def test_read_frames_multiple(self): + """Test reading multiple frames.""" + lines = dump.read_frames(self.dump_file, [0, 1]) + self.assertEqual(len(lines), 22) # 11 lines per frame * 2 frames + + def test_read_frames_out_of_order(self): + """Test reading frames in non-sequential order.""" + lines1 = dump.read_frames(self.dump_file, [1, 0]) + lines2 = dump.read_frames(self.dump_file, [0, 1]) + self.assertEqual(len(lines1), len(lines2)) + + def test_read_frames_empty(self): + """Test reading with empty frame list.""" + lines = dump.read_frames(self.dump_file, []) + self.assertEqual(len(lines), 0) + + def test_load_file_with_f_idx(self): + """Test enhanced load_file with f_idx parameter.""" + # Load specific frame + lines = dump.load_file(self.dump_file, f_idx=[1]) + self.assertEqual(len(lines), 11) + + # Load multiple frames + lines = dump.load_file(self.dump_file, f_idx=[0, 1]) + self.assertEqual(len(lines), 22) + + # Test that f_idx overrides begin/step + lines = dump.load_file(self.dump_file, begin=1, step=1, f_idx=[0]) + self.assertEqual(len(lines), 11) + + def test_system_with_f_idx(self): + """Test dpdata.System with f_idx parameter.""" + # Load all frames for comparison + system_all = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map) + + # Load only second frame + system_f1 = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[1]) + + self.assertEqual(len(system_all.data["coords"]), 2) + self.assertEqual(len(system_f1.data["coords"]), 1) + + # Check that the frame data matches + np.testing.assert_array_almost_equal( + system_all.data["coords"][1], + system_f1.data["coords"][0] + ) + np.testing.assert_array_almost_equal( + system_all.data["cells"][1], + system_f1.data["cells"][0] + ) + + def test_load_frames_from_trajectories(self): + """Test the frames_dict pattern.""" + frames_dict = { + self.dump_file: [0, 1] + } + + data = dump.load_frames_from_trajectories(frames_dict, type_map=self.type_map) + + self.assertIn("coords", data) + self.assertIn("cells", data) + self.assertEqual(len(data["coords"]), 2) + self.assertEqual(len(data["cells"]), 2) + + def test_load_frames_from_trajectories_single(self): + """Test the frames_dict pattern with single frame.""" + frames_dict = { + self.dump_file: [1] + } + + data = dump.load_frames_from_trajectories(frames_dict, type_map=self.type_map) + + self.assertIn("coords", data) + self.assertIn("cells", data) + self.assertEqual(len(data["coords"]), 1) + self.assertEqual(len(data["cells"]), 1) + + def test_efficiency_comparison(self): + """Compare efficiency by verifying we get the same results.""" + # Traditional approach: load all then filter + system_traditional = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map) + filtered_traditional = system_traditional.sub_system([1]) + + # New efficient approach: load only frame 1 + system_efficient = dpdata.System(self.dump_file, fmt="lammps/dump", type_map=self.type_map, f_idx=[1]) + + # Results should be identical + np.testing.assert_array_almost_equal( + filtered_traditional.data["coords"][0], + system_efficient.data["coords"][0] + ) + np.testing.assert_array_almost_equal( + filtered_traditional.data["cells"][0], + system_efficient.data["cells"][0] + ) + + def setUp_comp_sys(self): + """Set up comparison systems for inherited tests.""" + pass # Already set up in setUp + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file