Skip to content

Commit

Permalink
Add option to save data to an output directory
Browse files Browse the repository at this point in the history
  • Loading branch information
finsberg committed Sep 5, 2024
1 parent ae71392 commit a2586fd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 47 deletions.
40 changes: 29 additions & 11 deletions examples/lv_coupling_3D0D.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
dt = 1e-3



def print_table(time, current_volume, target_volume, pressure):
from rich.table import Table

Expand Down Expand Up @@ -64,9 +63,8 @@ def get_cavity_volume_form(mesh, u=None, xshift=5.0):
return vol_form



def model(comm):
geodir = Path("lv_ellipsoid")
def model(comm, outdir: Path):
geodir = outdir / "geometry"

if not geodir.exists():
# Make sure we don't create the directory before all
Expand Down Expand Up @@ -172,7 +170,7 @@ def dirichlet_bc(

vtx = dolfinx.io.VTXWriter(
problem.geometry.mesh.comm,
"displacement.bp",
outdir / "displacement.bp",
[u],
engine="BP4",
)
Expand All @@ -195,10 +193,22 @@ def get_activation(t: float):
i = t * 1000 % 1000
return normal_activation[int(i)]

def callback(t: float, save=False):
def callback(model, t: float, save=False):
if save:
u.x.array[:] = problem.state.x.array
vtx.write(t)

fig, ax = plt.subplots(3, 1)

ax[0].plot(model.results["V_LV"], model.results["p_LV"])
ax[0].set_xlabel("V [mL]")
ax[0].set_ylabel("p [mmHg]")

ax[1].plot(model.results["time"], model.results["p_LV"])
ax[2].plot(model.results["time"], model.results["V_LV"])

fig.savefig(outdir / "pv_loop_incremental")

value = get_activation(t)

logger.debug(f"Time{t} with activation: {value}")
Expand All @@ -209,8 +219,8 @@ def callback(t: float, save=False):
volume_form = get_cavity_volume_form(geometry.mesh, u=u)
volume = dolfinx.fem.form(volume_form * geometry.ds(geometry.markers["ENDO"][0]))
initial_volume = geo.mesh.comm.allreduce(
dolfinx.fem.assemble_scalar(volume), op=MPI.SUM
)
dolfinx.fem.assemble_scalar(volume), op=MPI.SUM
)
logger.info(f"Initial volume: {initial_volume}")

@lru_cache
Expand Down Expand Up @@ -262,14 +272,22 @@ def p_LV_func(V_LV, t):


def main(comm):
callback, p_LV_func, initial_volume = model(comm)

outdir = Path("results-lv_coupling_3D0D")
outdir.mkdir(exist_ok=True)
callback, p_LV_func, initial_volume = model(comm, outdir=outdir)

mL = ureg.mL

add_units = False

circulation = Regazzoni2020(
add_units=add_units, callback=callback, p_LV_func=p_LV_func, verbose=True, comm=comm,
add_units=add_units,
callback=callback,
p_LV_func=p_LV_func,
verbose=True,
comm=comm,
outdir=outdir,
)

if add_units:
Expand All @@ -294,7 +312,7 @@ def main(comm):
ax[1].plot(history["time"], history["p_LV"])
ax[2].plot(history["time"], history["V_LV"])

fig.savefig("pv_loop")
fig.savefig(outdir / "pv_loop")


if __name__ == "__main__":
Expand Down
75 changes: 45 additions & 30 deletions src/circulation/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Callable, Any, Protocol
from pathlib import Path
from abc import ABC, abstractmethod
import json
from collections import defaultdict
Expand Down Expand Up @@ -33,33 +34,47 @@ def remove_units(parameters: dict[str, Any]) -> dict[str, Any]:


class CallBack(Protocol):
def __call__(self, t: float, save: bool) -> None:
...
def __call__(self, model: "CirculationModel", t: float = 0, save: bool = True) -> None: ...


def dummy_callback(model: "CirculationModel", t: float = 0, save: bool = True) -> None:
pass


class CirculationModel(ABC):
def __init__(
self,
parameters: dict[str, Any] | None = None,
outdir: Path = Path("results"),
add_units: bool = False,
callback: Callable[[float, bool], None] | None = None,
callback: CallBack | None = None,
verbose: bool = False,
comm=None,
# save_state: Callable[[fl]]
callback_save_state: CallBack | None = None,
):
self.parameters = type(self).default_parameters()
if parameters is not None:
self.parameters.update(parameters)
if not add_units:
self.parameters = remove_units(self.parameters)
self._add_units = add_units
self.outdir = outdir
outdir.mkdir(exist_ok=True, parents=True)

if callback is not None:
assert callable(callback), "callback must be callable"

self.callback = callback
else:
self.callback = lambda t, b: None
self.callback = dummy_callback

if callback_save_state is not None:
assert callable(callback_save_state), "callback_save_state must be callable"

self.callback_save_state = callback_save_state
else:
self.callback_save_state = dummy_callback

self._verbose = verbose
loglevel = logging.DEBUG if verbose else logging.INFO
log.setup_logging(level=loglevel)
Expand All @@ -71,6 +86,12 @@ def _initialize(self):
self.update_state()
self.update_static_variables(0.0)

if self._comm is None or (self._comm is not None and self._comm.rank == 0):
# Dump parameters to file
(self.outdir / "parameters.json").write_text(json.dumps(self.parameters, indent=2))
# Dump initial conditions to file
(self.outdir / "initial_conditions.json").write_text(json.dumps(self.state, indent=2))

@property
def THB(self):
if self._add_units:
Expand All @@ -80,8 +101,7 @@ def THB(self):

@staticmethod
@abstractmethod
def default_parameters() -> dict[str, Any]:
...
def default_parameters() -> dict[str, Any]: ...

@abstractmethod
def update_static_variables(self, t: float):
Expand All @@ -96,8 +116,7 @@ def update_state(self, state: dict[str, float] | None = None):

@staticmethod
@abstractmethod
def default_initial_conditions() -> dict[str, float]:
...
def default_initial_conditions() -> dict[str, float]: ...

def time_varying_elastance(self, EA, EB, tC, TC, TR, **kwargs):
return time_varying_elastance.blanco_ventricle(
Expand Down Expand Up @@ -126,8 +145,7 @@ def _R(
)

@abstractmethod
def step(self, t: float, dt: float) -> None:
...
def step(self, t: float, dt: float) -> None: ...

def solve(
self,
Expand All @@ -136,6 +154,7 @@ def solve(
initial_state: dict[str, float] | None = None,
dt: float = 1e-3,
dt_eval: float | None = None,
checkpoint: int = 0,
):
logger.info("Running circulation model")
if T is None:
Expand All @@ -149,6 +168,11 @@ def solve(
else:
output_every_n_steps = np.round(dt_eval / dt)

if checkpoint > 0:
checkoint_every_n_steps = np.round(checkpoint / dt)
else:
checkoint_every_n_steps = np.inf

self.update_state(state=initial_state)
self.initialize_output()
t = 0.0
Expand All @@ -162,49 +186,40 @@ def solve(

i = 0
while t < T:
self.callback(t, i % output_every_n_steps == 0)
self.callback(self, t, i % output_every_n_steps == 0)
self.step(t, dt)
if i % output_every_n_steps == 0:
self.store(t)
if self._verbose:
self.print_info()

if i % checkoint_every_n_steps == 0:
self.save_state()
t += dt
i += 1

duration = time.time() - time_start

logger.info(f"Done running circulation model in {duration:.2f} s")
self.save_state()
return self.results

def initialize_output(self):
self.results = defaultdict(list)

def store(self, t):
get = lambda x: np.copy(x) if not self._add_units else x.magnitude
get = lambda x: float(np.copy(x)) if not self._add_units else x.magnitude

self.results["time"].append(get(t))
for k, v in self.state.items():
self.results[k].append(get(v))
for k, v in self.var.items():
self.results[k].append(get(v))

import matplotlib.pyplot as plt

if self._comm is not None and self._comm.rank == 0:
fig, ax = plt.subplots(3, 1)

ax[0].plot(self.results["V_LV"], self.results["p_LV"])
ax[0].set_xlabel("V [mL]")
ax[0].set_ylabel("p [mmHg]")

ax[1].plot(self.results["time"], self.results["p_LV"])
ax[2].plot(self.results["time"], self.results["V_LV"])

fig.savefig("pv_loop")

def save_state(self, filename):
with open(filename, mode="w", newline="") as outfile:
json.dump(self.state, outfile, indent=2)
def save_state(self):
self.callback_save_state(self)
(self.outdir / "state.json").write_text(json.dumps(self.state, indent=2))
(self.outdir / "results.json").write_text(json.dumps(self.results, indent=2))

@property
def volumes(self) -> dict[str, float]:
Expand Down
19 changes: 13 additions & 6 deletions src/circulation/regazzoni2020.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Callable, Any
from pathlib import Path

from . import base
from . import units
Expand All @@ -27,12 +28,18 @@ def __init__(
add_units=False,
p_LV_func: Callable[[float, float], float] | None = None,
leak: Callable[[float], float] | None = None,
callback: Callable[[float, float], None] | None = None,
callback: base.CallBack | None = None,
verbose: bool = False,
comm=None,
outdir: Path = Path("results-regazzoni"),
):
super().__init__(
parameters, add_units=add_units, callback=callback, verbose=verbose, comm=comm
parameters,
add_units=add_units,
callback=callback,
verbose=verbose,
comm=comm,
outdir=outdir,
)
chambers = self.parameters["chambers"]
valves = self.parameters["valves"]
Expand Down Expand Up @@ -141,7 +148,7 @@ def default_parameters() -> dict[str, Any]:
"R_AR": 0.8 * mmHg * s / mL,
"C_AR": 1.2 * mL / mmHg,
"R_VEN": 0.26 * mmHg * s / mL,
"C_VEN": 60.0 * mL / mmHg,
"C_VEN": 130 * mL / mmHg,
"L_AR": 5e-3 * mmHg * s**2 / mL,
"L_VEN": 5e-4 * mmHg * s**2 / mL,
},
Expand Down Expand Up @@ -220,16 +227,16 @@ def step(self, t, dt):
L_AR_PUL = self.parameters["circulation"]["PUL"]["L_AR"]
L_VEN_PUL = self.parameters["circulation"]["PUL"]["L_VEN"]

self.state["V_LA"] += dt * (Q_VEN_PUL - Q_MV - self.leak(t))
self.state["V_LA"] += dt * (Q_VEN_PUL - Q_MV)
self.state["V_LV"] += dt * (Q_MV - Q_AV)
self.state["V_RA"] += dt * (Q_VEN_SYS - Q_TV)
self.state["V_RV"] += dt * (Q_TV - Q_PV)
self.state["p_AR_SYS"] += dt * (Q_AV - Q_AR_SYS) / C_AR_SYS
self.state["p_VEN_SYS"] += dt * (Q_AR_SYS - Q_VEN_SYS) / C_VEN_SYS
self.state["p_AR_PUL"] += dt * (Q_PV - Q_AR_PUL) / C_AR_PUL
self.state["p_VEN_PUL"] += dt * (Q_AR_PUL - Q_VEN_PUL) / C_VEN_PUL
self.state["p_VEN_PUL"] += dt * (Q_AR_PUL - Q_VEN_PUL - self.leak(t)) / C_VEN_PUL
self.state["Q_AR_SYS"] += -dt * ((R_AR_SYS * Q_AR_SYS + p_VEN_SYS - p_AR_SYS) / L_AR_SYS)
self.state["Q_VEN_SYS"] += -dt * (R_VEN_SYS * Q_VEN_SYS + p_RA - p_VEN_SYS) / L_VEN_SYS
self.state["Q_VEN_SYS"] += -dt * ((R_VEN_SYS * Q_VEN_SYS + p_RA - p_VEN_SYS) / L_VEN_SYS)
self.state["Q_AR_PUL"] += -dt * (R_AR_PUL * Q_AR_PUL + p_VEN_PUL - p_AR_PUL) / L_AR_PUL
self.state["Q_VEN_PUL"] += -dt * (R_VEN_PUL * Q_VEN_PUL + p_LA - p_VEN_PUL) / L_VEN_PUL

Expand Down

0 comments on commit a2586fd

Please sign in to comment.