First variant of a ModelrunnerStorage class (#555)
This facilitates integrating `py-modelrunner` with `py-pde` without tight coupling between the two packages.
david-zwicker authored Apr 13, 2024
1 parent 1eba064 commit cbe668c
Showing 12 changed files with 412 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pde/__init__.py
@@ -6,7 +6,7 @@
 # determine the package version
 try:
     # try reading version of the automatically generated module
-    from ._version import __version__  # type: ignore
+    from ._version import __version__
 except ImportError:
     # determine version automatically from VCS information
     from importlib.metadata import PackageNotFoundError, version
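The remainder of this fallback is collapsed in the hunk above. The standard pattern such a fallback follows is sketched below (a sketch only, assuming the PyPI distribution name "py-pde"; the elided code itself is not shown in this commit view):

    try:
        __version__ = version("py-pde")  # version of the installed distribution
    except PackageNotFoundError:
        __version__ = "unknown"  # package is not installed properly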
6 changes: 6 additions & 0 deletions pde/storage/__init__.py
@@ -6,6 +6,7 @@
    ~memory.get_memory_storage
    ~memory.MemoryStorage
+   ~modelrunner.ModelrunnerStorage
    ~file.FileStorage
    ~movie.MovieStorage
@@ -15,3 +16,8 @@
 from .file import FileStorage
 from .memory import MemoryStorage, get_memory_storage
 from .movie import MovieStorage
+
+try:
+    from .modelrunner import ModelrunnerStorage
+except ImportError:
+    ...  # ModelrunnerStorage is only available when py-modelrunner is installed
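Because the import above is guarded, downstream code can probe for the optional backend instead of importing py-modelrunner itself. A minimal sketch (the hasattr probe is illustrative, not py-pde API):

    import pde.storage

    # the name only exists when py-modelrunner is installed
    if hasattr(pde.storage, "ModelrunnerStorage"):
        print("ModelrunnerStorage backend is available")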
151 changes: 151 additions & 0 deletions pde/storage/modelrunner.py
@@ -0,0 +1,151 @@
"""
Defines a class storing data using :mod:`modelrunner`.
.. codeauthor:: David Zwicker <[email protected]>
"""

from __future__ import annotations

import modelrunner as mr
import numpy as np

from ..fields.base import FieldBase
from .base import InfoDict, StorageBase, WriteModeType


class ModelrunnerStorage(StorageBase):
"""store discretized fields in a :mod:`modelrunner` storage"""

def __init__(
self,
storage: mr.storage.StorageGroup,
*,
loc: mr.storage.Location = "trajectory",
info: InfoDict | None = None,
write_mode: WriteModeType = "truncate_once",
):
"""
Args:
storage (:class:`~modelrunner.storage.group.StorageGroup`):
Modelrunner storage used for storing the trajectory
loc (str or list of str):
The location in the storage where the trajectory data is written.
info (dict):
Supplies extra information that is stored in the storage
write_mode (str):
Determines how new data is added to already existing data. Possible
values are: 'append' (data is always appended), 'truncate' (data is
cleared every time this storage is used for writing), or 'truncate_once'
(data is cleared for the first writing, but appended subsequently).
Alternatively, specifying 'readonly' will disable writing completely.
"""
super().__init__(info=info, write_mode=write_mode)
self.storage = storage
self.loc = loc
self._writer: mr.storage.TrajectoryWriter | None = None
self._reader: mr.storage.Trajectory | None = None

def close(self) -> None:
"""close the currently opened trajectory writer"""
if self._writer is not None:
self._writer.close()
self._writer = None

def __enter__(self) -> ModelrunnerStorage:
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.close()

def __len__(self):
"""return the number of stored items, i.e., time steps"""
return len(self.times)

@property
def _io(self) -> mr.storage.TrajectoryWriter | mr.storage.Trajectory:
""":class:`~modelrunner.storage.group.StorageGroup`: Group with all data"""
if self._writer is not None:
return self._writer
if self._reader is None:
self._reader = mr.storage.Trajectory(self.storage, loc=self.loc)
return self._reader

@property
def times(self):
""":class:`~numpy.ndarray`: The times at which data is available"""
return self._io.times

@property
def data(self):
""":class:`~numpy.ndarray`: The actual data for all time"""
return self._io._storage.read_array(self._io._loc + ["data"])

def clear(self, clear_data_shape: bool = False):
"""truncate the storage by removing all stored data.
Args:
clear_data_shape (bool):
Flag determining whether the data shape is also deleted.
"""
if self.loc in self.storage:
raise NotImplementedError("Cannot delete existing trajectory")
super().clear(clear_data_shape=clear_data_shape)

def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
"""initialize the storage for writing data
Args:
field (:class:`~pde.fields.FieldBase`):
An example of the data that will be written to extract the grid and the
data_shape
info (dict):
Supplies extra information that is stored in the storage
"""
if self._writer:
raise RuntimeError(f"{self.__class__.__name__} is already in writing mode")
if self._reader:
self._reader.close()

# delete data if truncation is requested. This is for instance necessary
# to remove older data with incompatible data_shape
if self.write_mode == "truncate" or self.write_mode == "truncate_once":
self.clear(clear_data_shape=True)

# initialize the writing, setting current data shape
super().start_writing(field, info=info)

# initialize the file for writing with the correct mode
self._logger.debug(f"Start writing with mode `{self.write_mode}`")
if self.write_mode == "truncate_once":
self.write_mode = "append" # do not truncate for next writing
elif self.write_mode == "readonly":
raise RuntimeError("Cannot write in read-only mode")
elif self.write_mode not in {"truncate", "append"}:
raise ValueError(
f"Unknown write mode `{self.write_mode}`. Possible values are "
"`truncate_once`, `truncate`, and `append`"
)

if info:
self.info.update(info)
self._writer = mr.storage.TrajectoryWriter(
self.storage, loc=self.loc, attrs=self.info, mode="append"
)

def _append_data(self, data: np.ndarray, time: float) -> None:
"""append a new data set
Args:
data (:class:`~numpy.ndarray`): The actual data
time (float): The time point associated with the data
"""
assert self._writer is not None
self._writer.append(data, float(time))

def end_writing(self) -> None:
"""finalize the storage after writing.
This makes sure the data is actually written to a file when
self.keep_opened == False
"""
self.close()
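A hedged usage sketch for the class above: the caller opens a py-modelrunner storage group (how to do so is left to py-modelrunner; see its documentation), and the tracker machinery inherited from StorageBase records the trajectory. The top-level name pde.ModelrunnerStorage is taken from the test resource added below; "truncate_once" mirrors the write-mode semantics documented in __init__:

    import pde

    def write_trajectory(storage) -> None:
        """store a short diffusion run; `storage` is a py-modelrunner StorageGroup"""
        field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8]))
        eq = pde.DiffusionPDE()
        # "truncate_once" clears stale data on the first write, then appends
        writer = pde.ModelrunnerStorage(storage, write_mode="truncate_once")
        eq.solve(field, t_range=1, dt=0.1, backend="numpy", tracker=writer.tracker(0.1))
        print(writer.times)  # time points recorded in the trajectory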
76 changes: 44 additions & 32 deletions pde/storage/movie.py
@@ -473,10 +473,51 @@ def times(self):

         return times

+    def _iter_data(self) -> Iterator[np.ndarray]:
+        """iterate over all stored fields"""
+        import ffmpeg  # lazy loading so it's not a hard dependence
+
+        if "width" not in self.info:
+            self._read_metadata()
+        if self._field is None:
+            self._init_field()
+        assert self._field is not None
+        self._init_normalization(self._field)
+        assert self._norms is not None
+        frame_shape = (self.info["width"], self.info["height"], self._format.channels)
+        data_shape = (len(self._norms), self.info["width"], self.info["height"])
+        data = np.empty(data_shape, dtype=self._dtype)
+        frame_bytes = np.prod(frame_shape) * self._format.bytes_per_channel
+
+        # iterate over the entire movie
+        f_input = ffmpeg.input(self.filename, loglevel=self.loglevel)
+        f_output = f_input.output(
+            "pipe:", format="rawvideo", pix_fmt=self._format.pix_fmt_data
+        )
+        proc = f_output.run_async(pipe_stdout=True)
+        while True:
+            read_bytes = proc.stdout.read(frame_bytes)
+            if not read_bytes:
+                break
+            frame = np.frombuffer(read_bytes, self._format.dtype).reshape(frame_shape)
+
+            for i, norm in enumerate(self._norms):
+                frame_data = self._format.data_from_frame(frame[:, :, i])
+                data[i, :, :] = norm.inverse(frame_data)
+
+            yield data
+
     @property
     def data(self):
-        """:class:`~numpy.ndarray`: The actual data for all time"""
-        raise NotImplementedError
+        """:class:`~numpy.ndarray`: The actual data for all times"""
+        it = self._iter_data()  # get the iterator over all data
+        first_frame = next(it)  # get the first frame to obtain necessary information
+        # allocate memory for all data
+        data = np.empty((len(self),) + first_frame.shape, dtype=first_frame.dtype)
+        data[0] = first_frame  # set the first frame
+        for i, frame_data in enumerate(it, 1):  # set all subsequent frames
+            data[i] = frame_data
+        return data
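The rewritten property uses a generic pattern: pull the first item off the iterator to learn shape and dtype, preallocate once, then fill. A standalone sketch of the same idea (illustrative, not py-pde API); note that _iter_data reuses a single buffer, so the assignment into the preallocated array is what makes the stored frames independent copies:

    import numpy as np
    from collections.abc import Iterator

    def materialize(frames: Iterator[np.ndarray], n: int) -> np.ndarray:
        first = next(frames)  # template fixing shape and dtype
        out = np.empty((n,) + first.shape, dtype=first.dtype)
        out[0] = first
        for i, frame in enumerate(frames, 1):  # fill the remaining slots
            out[i] = frame
        return out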

     def _get_field(self, t_index: int) -> FieldBase:
         """return the field corresponding to the given time index
@@ -533,36 +574,7 @@ def _get_field(self, t_index: int) -> FieldBase:

     def __iter__(self) -> Iterator[FieldBase]:
         """iterate over all stored fields"""
-        import ffmpeg  # lazy loading so it's not a hard dependence
-
-        if "width" not in self.info:
-            self._read_metadata()
-        if self._field is None:
-            self._init_field()
-        assert self._field is not None
-        self._init_normalization(self._field)
-        assert self._norms is not None
-        frame_shape = (self.info["width"], self.info["height"], self._format.channels)
-        data_shape = (len(self._norms), self.info["width"], self.info["height"])
-        data = np.empty(data_shape, dtype=self._dtype)
-        frame_bytes = np.prod(frame_shape) * self._format.bytes_per_channel
-
-        # iterate over entire movie
-        f_input = ffmpeg.input(self.filename, loglevel=self.loglevel)
-        f_output = f_input.output(
-            "pipe:", format="rawvideo", pix_fmt=self._format.pix_fmt_data
-        )
-        proc = f_output.run_async(pipe_stdout=True)
-        while True:
-            read_bytes = proc.stdout.read(frame_bytes)
-            if not read_bytes:
-                break
-            frame = np.frombuffer(read_bytes, self._format.dtype).reshape(frame_shape)
-
-            for i, norm in enumerate(self._norms):
-                frame_data = self._format.data_from_frame(frame[:, :, i])
-                data[i, :, :] = norm.inverse(frame_data)
-
+        for data in self._iter_data():
             # create the field with the data of the given index
             assert self._field is not None
             field = self._field.copy()
4 changes: 3 additions & 1 deletion pde/trackers/base.py
@@ -178,7 +178,7 @@ def from_data(cls, data: TrackerCollectionDataType, **kwargs) -> TrackerCollection
             trackers = [data]
         elif isinstance(data, str):
             trackers = [TrackerBase.from_data(data, **kwargs)]
-        else:
+        elif isinstance(data, (list, tuple)):
             # initialize trackers from a sequence
             trackers, interrupt_ids = [], set()
             for tracker in data:
@@ -190,6 +190,8 @@ def from_data(cls, data: TrackerCollectionDataType, **kwargs) -> TrackerCollection
                     tracker_obj.interrupt = tracker_obj.interrupt.copy()
                 interrupt_ids.add(id(tracker_obj.interrupt))
                 trackers.append(tracker_obj)
+        else:
+            raise TypeError(f"Cannot initialize trackers from class `{data.__class__}`")

         return cls(trackers)

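A short sketch of the inputs the rewritten branches accept: a single tracker, a string shorthand, or a sequence mixing both, while any other type now fails loudly with a TypeError (the grid size and time step below are illustrative):

    import pde

    field = pde.ScalarField.random_uniform(pde.UnitGrid([8]))
    eq = pde.DiffusionPDE()
    # a sequence mixing a string shorthand and a tracker instance
    eq.solve(field, t_range=1, dt=0.1, tracker=["progress", pde.PrintTracker()])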
12 changes: 12 additions & 0 deletions scripts/run_tests.py
@@ -121,6 +121,7 @@ def run_unit_tests(
     coverage: bool = False,
     nojit: bool = False,
     early: bool = False,
+    debug: bool = False,
     pattern: str = None,
 ) -> int:
"""run the unit tests
Expand All @@ -133,6 +134,7 @@ def run_unit_tests(
coverage (bool): Whether to determine the test coverage
nojit (bool): Whether to disable numba jit compilation
early (bool): Whether to fail at the first test
debug (bool): Whether extra output useful for debugging should be emitted
pattern (str): A pattern that determines which tests are ran
Returns:
Expand Down Expand Up @@ -172,6 +174,9 @@ def run_unit_tests(
         args.append("--runslow")  # also run slow tests
     if runinteractive:
         args.append("--runinteractive")  # also run interactive tests
+    if debug:
+        # show debug log entries live
+        args.extend(["-o", "log_cli=true", "--log-cli-level=debug"])
     if use_mpi:
         try:
             import numba_mpi  # @UnusedImport
@@ -306,6 +311,12 @@ def main() -> int:
         default=False,
         help="Return at first failed test",
     )
+    group.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Show extra debug output",
+    )
     group.add_argument(
         "--pattern",
         metavar="PATTERN",
@@ -346,6 +357,7 @@ def main() -> int:
         num_cores=args.num_cores,
         nojit=args.nojit,
         early=args.early,
+        debug=args.debug,
         pattern=args.pattern,
     )
     retcodes.append(retcode)
4 changes: 2 additions & 2 deletions scripts/tests_debug.sh
@@ -6,9 +6,9 @@ if [ ! -z $1 ]
 then
     # test pattern was specified
     echo 'Run unittests with pattern '$1':'
-    ./run_tests.py --unit --runslow --nojit --pattern "$1"
+    ./run_tests.py --unit --debug --runslow --nojit --pattern "$1"
 else
     # test pattern was not specified
     echo 'Run all unittests:'
-    ./run_tests.py --unit --nojit
+    ./run_tests.py --unit --debug --nojit
 fi
15 changes: 15 additions & 0 deletions tests/resources/run_pde.py
@@ -0,0 +1,15 @@
import pde


def run_pde(t_range, storage):
    """run a pde and store trajectory"""
    field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8]))
    eq = pde.DiffusionPDE()
    result = eq.solve(
        field,
        t_range=t_range,
        dt=0.1,
        backend="numpy",
        tracker=pde.ModelrunnerStorage(storage).tracker(1),
    )
    return {"field": result}