From a2586fdd5a9e655066608a29cee0bf6a938a3ff8 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Thu, 5 Sep 2024 22:54:55 +0200 Subject: [PATCH] Add option to save data to an output directory --- examples/lv_coupling_3D0D.py | 40 ++++++++++++----- src/circulation/base.py | 75 +++++++++++++++++++------------- src/circulation/regazzoni2020.py | 19 +++++--- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/examples/lv_coupling_3D0D.py b/examples/lv_coupling_3D0D.py index 43e7bf9..765684b 100644 --- a/examples/lv_coupling_3D0D.py +++ b/examples/lv_coupling_3D0D.py @@ -27,7 +27,6 @@ dt = 1e-3 - def print_table(time, current_volume, target_volume, pressure): from rich.table import Table @@ -64,9 +63,8 @@ def get_cavity_volume_form(mesh, u=None, xshift=5.0): return vol_form - -def model(comm): - geodir = Path("lv_ellipsoid") +def model(comm, outdir: Path): + geodir = outdir / "geometry" if not geodir.exists(): # Make sure we don't create the directory before all @@ -172,7 +170,7 @@ def dirichlet_bc( vtx = dolfinx.io.VTXWriter( problem.geometry.mesh.comm, - "displacement.bp", + outdir / "displacement.bp", [u], engine="BP4", ) @@ -195,10 +193,22 @@ def get_activation(t: float): i = t * 1000 % 1000 return normal_activation[int(i)] - def callback(t: float, save=False): + def callback(model, t: float, save=False): if save: u.x.array[:] = problem.state.x.array vtx.write(t) + + fig, ax = plt.subplots(3, 1) + + ax[0].plot(model.results["V_LV"], model.results["p_LV"]) + ax[0].set_xlabel("V [mL]") + ax[0].set_ylabel("p [mmHg]") + + ax[1].plot(model.results["time"], model.results["p_LV"]) + ax[2].plot(model.results["time"], model.results["V_LV"]) + + fig.savefig(outdir / "pv_loop_incremental") + value = get_activation(t) logger.debug(f"Time{t} with activation: {value}") @@ -209,8 +219,8 @@ def callback(t: float, save=False): volume_form = get_cavity_volume_form(geometry.mesh, u=u) volume = dolfinx.fem.form(volume_form * geometry.ds(geometry.markers["ENDO"][0])) initial_volume = geo.mesh.comm.allreduce( - dolfinx.fem.assemble_scalar(volume), op=MPI.SUM - ) + dolfinx.fem.assemble_scalar(volume), op=MPI.SUM + ) logger.info(f"Initial volume: {initial_volume}") @lru_cache @@ -262,14 +272,22 @@ def p_LV_func(V_LV, t): def main(comm): - callback, p_LV_func, initial_volume = model(comm) + + outdir = Path("results-lv_coupling_3D0D") + outdir.mkdir(exist_ok=True) + callback, p_LV_func, initial_volume = model(comm, outdir=outdir) mL = ureg.mL add_units = False circulation = Regazzoni2020( - add_units=add_units, callback=callback, p_LV_func=p_LV_func, verbose=True, comm=comm, + add_units=add_units, + callback=callback, + p_LV_func=p_LV_func, + verbose=True, + comm=comm, + outdir=outdir, ) if add_units: @@ -294,7 +312,7 @@ def main(comm): ax[1].plot(history["time"], history["p_LV"]) ax[2].plot(history["time"], history["V_LV"]) - fig.savefig("pv_loop") + fig.savefig(outdir / "pv_loop") if __name__ == "__main__": diff --git a/src/circulation/base.py b/src/circulation/base.py index cff5940..fd1c10b 100644 --- a/src/circulation/base.py +++ b/src/circulation/base.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Callable, Any, Protocol +from pathlib import Path from abc import ABC, abstractmethod import json from collections import defaultdict @@ -33,19 +34,23 @@ def remove_units(parameters: dict[str, Any]) -> dict[str, Any]: class CallBack(Protocol): - def __call__(self, t: float, save: bool) -> None: - ... + def __call__(self, model: "CirculationModel", t: float = 0, save: bool = True) -> None: ... + + +def dummy_callback(model: "CirculationModel", t: float = 0, save: bool = True) -> None: + pass class CirculationModel(ABC): def __init__( self, parameters: dict[str, Any] | None = None, + outdir: Path = Path("results"), add_units: bool = False, - callback: Callable[[float, bool], None] | None = None, + callback: CallBack | None = None, verbose: bool = False, comm=None, - # save_state: Callable[[fl]] + callback_save_state: CallBack | None = None, ): self.parameters = type(self).default_parameters() if parameters is not None: @@ -53,13 +58,23 @@ def __init__( if not add_units: self.parameters = remove_units(self.parameters) self._add_units = add_units + self.outdir = outdir + outdir.mkdir(exist_ok=True, parents=True) if callback is not None: assert callable(callback), "callback must be callable" self.callback = callback else: - self.callback = lambda t, b: None + self.callback = dummy_callback + + if callback_save_state is not None: + assert callable(callback_save_state), "callback_save_state must be callable" + + self.callback_save_state = callback_save_state + else: + self.callback_save_state = dummy_callback + self._verbose = verbose loglevel = logging.DEBUG if verbose else logging.INFO log.setup_logging(level=loglevel) @@ -71,6 +86,12 @@ def _initialize(self): self.update_state() self.update_static_variables(0.0) + if self._comm is None or (self._comm is not None and self._comm.rank == 0): + # Dump parameters to file + (self.outdir / "parameters.json").write_text(json.dumps(self.parameters, indent=2)) + # Dump initial conditions to file + (self.outdir / "initial_conditions.json").write_text(json.dumps(self.state, indent=2)) + @property def THB(self): if self._add_units: @@ -80,8 +101,7 @@ def THB(self): @staticmethod @abstractmethod - def default_parameters() -> dict[str, Any]: - ... + def default_parameters() -> dict[str, Any]: ... @abstractmethod def update_static_variables(self, t: float): @@ -96,8 +116,7 @@ def update_state(self, state: dict[str, float] | None = None): @staticmethod @abstractmethod - def default_initial_conditions() -> dict[str, float]: - ... + def default_initial_conditions() -> dict[str, float]: ... def time_varying_elastance(self, EA, EB, tC, TC, TR, **kwargs): return time_varying_elastance.blanco_ventricle( @@ -126,8 +145,7 @@ def _R( ) @abstractmethod - def step(self, t: float, dt: float) -> None: - ... + def step(self, t: float, dt: float) -> None: ... def solve( self, @@ -136,6 +154,7 @@ def solve( initial_state: dict[str, float] | None = None, dt: float = 1e-3, dt_eval: float | None = None, + checkpoint: int = 0, ): logger.info("Running circulation model") if T is None: @@ -149,6 +168,11 @@ def solve( else: output_every_n_steps = np.round(dt_eval / dt) + if checkpoint > 0: + checkoint_every_n_steps = np.round(checkpoint / dt) + else: + checkoint_every_n_steps = np.inf + self.update_state(state=initial_state) self.initialize_output() t = 0.0 @@ -162,25 +186,29 @@ def solve( i = 0 while t < T: - self.callback(t, i % output_every_n_steps == 0) + self.callback(self, t, i % output_every_n_steps == 0) self.step(t, dt) if i % output_every_n_steps == 0: self.store(t) if self._verbose: self.print_info() + + if i % checkoint_every_n_steps == 0: + self.save_state() t += dt i += 1 duration = time.time() - time_start logger.info(f"Done running circulation model in {duration:.2f} s") + self.save_state() return self.results def initialize_output(self): self.results = defaultdict(list) def store(self, t): - get = lambda x: np.copy(x) if not self._add_units else x.magnitude + get = lambda x: float(np.copy(x)) if not self._add_units else x.magnitude self.results["time"].append(get(t)) for k, v in self.state.items(): @@ -188,23 +216,10 @@ def store(self, t): for k, v in self.var.items(): self.results[k].append(get(v)) - import matplotlib.pyplot as plt - - if self._comm is not None and self._comm.rank == 0: - fig, ax = plt.subplots(3, 1) - - ax[0].plot(self.results["V_LV"], self.results["p_LV"]) - ax[0].set_xlabel("V [mL]") - ax[0].set_ylabel("p [mmHg]") - - ax[1].plot(self.results["time"], self.results["p_LV"]) - ax[2].plot(self.results["time"], self.results["V_LV"]) - - fig.savefig("pv_loop") - - def save_state(self, filename): - with open(filename, mode="w", newline="") as outfile: - json.dump(self.state, outfile, indent=2) + def save_state(self): + self.callback_save_state(self) + (self.outdir / "state.json").write_text(json.dumps(self.state, indent=2)) + (self.outdir / "results.json").write_text(json.dumps(self.results, indent=2)) @property def volumes(self) -> dict[str, float]: diff --git a/src/circulation/regazzoni2020.py b/src/circulation/regazzoni2020.py index f7cb6d1..cf2aaca 100644 --- a/src/circulation/regazzoni2020.py +++ b/src/circulation/regazzoni2020.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Callable, Any +from pathlib import Path from . import base from . import units @@ -27,12 +28,18 @@ def __init__( add_units=False, p_LV_func: Callable[[float, float], float] | None = None, leak: Callable[[float], float] | None = None, - callback: Callable[[float, float], None] | None = None, + callback: base.CallBack | None = None, verbose: bool = False, comm=None, + outdir: Path = Path("results-regazzoni"), ): super().__init__( - parameters, add_units=add_units, callback=callback, verbose=verbose, comm=comm + parameters, + add_units=add_units, + callback=callback, + verbose=verbose, + comm=comm, + outdir=outdir, ) chambers = self.parameters["chambers"] valves = self.parameters["valves"] @@ -141,7 +148,7 @@ def default_parameters() -> dict[str, Any]: "R_AR": 0.8 * mmHg * s / mL, "C_AR": 1.2 * mL / mmHg, "R_VEN": 0.26 * mmHg * s / mL, - "C_VEN": 60.0 * mL / mmHg, + "C_VEN": 130 * mL / mmHg, "L_AR": 5e-3 * mmHg * s**2 / mL, "L_VEN": 5e-4 * mmHg * s**2 / mL, }, @@ -220,16 +227,16 @@ def step(self, t, dt): L_AR_PUL = self.parameters["circulation"]["PUL"]["L_AR"] L_VEN_PUL = self.parameters["circulation"]["PUL"]["L_VEN"] - self.state["V_LA"] += dt * (Q_VEN_PUL - Q_MV - self.leak(t)) + self.state["V_LA"] += dt * (Q_VEN_PUL - Q_MV) self.state["V_LV"] += dt * (Q_MV - Q_AV) self.state["V_RA"] += dt * (Q_VEN_SYS - Q_TV) self.state["V_RV"] += dt * (Q_TV - Q_PV) self.state["p_AR_SYS"] += dt * (Q_AV - Q_AR_SYS) / C_AR_SYS self.state["p_VEN_SYS"] += dt * (Q_AR_SYS - Q_VEN_SYS) / C_VEN_SYS self.state["p_AR_PUL"] += dt * (Q_PV - Q_AR_PUL) / C_AR_PUL - self.state["p_VEN_PUL"] += dt * (Q_AR_PUL - Q_VEN_PUL) / C_VEN_PUL + self.state["p_VEN_PUL"] += dt * (Q_AR_PUL - Q_VEN_PUL - self.leak(t)) / C_VEN_PUL self.state["Q_AR_SYS"] += -dt * ((R_AR_SYS * Q_AR_SYS + p_VEN_SYS - p_AR_SYS) / L_AR_SYS) - self.state["Q_VEN_SYS"] += -dt * (R_VEN_SYS * Q_VEN_SYS + p_RA - p_VEN_SYS) / L_VEN_SYS + self.state["Q_VEN_SYS"] += -dt * ((R_VEN_SYS * Q_VEN_SYS + p_RA - p_VEN_SYS) / L_VEN_SYS) self.state["Q_AR_PUL"] += -dt * (R_AR_PUL * Q_AR_PUL + p_VEN_PUL - p_AR_PUL) / L_AR_PUL self.state["Q_VEN_PUL"] += -dt * (R_VEN_PUL * Q_VEN_PUL + p_LA - p_VEN_PUL) / L_VEN_PUL