Skip to content

Commit

Permalink
Make it possible to solve model in fenicsx by creating separate metho…
Browse files Browse the repository at this point in the history
…d for computing rhs
  • Loading branch information
finsberg committed Oct 7, 2024
1 parent bb024cb commit 900d5dd
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 143 deletions.
26 changes: 22 additions & 4 deletions examples/regazzoni.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,29 @@
setup_logging()
circulation = Regazzoni2020()

circulation.print_info()
from scipy.integrate import solve_ivp
import numpy as np


# res = solve_ivp(
# circulation.rhs,
# [0, 5],
# circulation.state_arr,
# t_eval=np.linspace(0, 5, 1000),
# method="RK45",
# max_step=1e-3,
# )
# plt.plot(res.y[0, :])
# plt.plot(res.y[1, :])
# plt.show()
# breakpoint()



# circulation.print_info()
# history = circulation.results

history = circulation.solve(
num_cycles=10,
)
history = circulation.solve(num_beats=10)
circulation.print_info()

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
Expand Down
170 changes: 109 additions & 61 deletions src/circulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from abc import ABC, abstractmethod
import json
from collections import defaultdict

import numpy as np
import time
import logging
Expand All @@ -17,8 +17,13 @@
logger = logging.getLogger(__name__)


def smooth_heavyside(x):
return np.arctan(np.pi / 2 * x * 200) * 1 / np.pi + 0.5
def smooth_heavyside(x, use_ufl=True):
if use_ufl:
import ufl

return ufl.atan(ufl.pi / 2 * x * 200) * 1 / ufl.pi + 0.5
else:
return np.arctan(np.pi / 2 * x * 200) * 1 / np.pi + 0.5


def remove_units(parameters: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -52,8 +57,7 @@ def external_blood(


class CallBack(Protocol):
def __call__(self, model: "CirculationModel", t: float = 0, save: bool = True) -> None:
...
def __call__(self, model: "CirculationModel", t: float = 0, save: bool = True) -> None: ...


def dummy_callback(model: "CirculationModel", t: float = 0, save: bool = True) -> None:
Expand Down Expand Up @@ -88,16 +92,17 @@ def __init__(
comm=None,
callback_save_state: CallBack | None = None,
initial_state: dict[str, float] | None = None,
theta: float = 0.5,
):
self.parameters = type(self).default_parameters()
if parameters is not None:
self.parameters = deep_update(self.parameters, parameters)

self._add_units = add_units
self._theta = theta

self.state = type(self).default_initial_conditions()
if initial_state is not None:
self.state.update(initial_state)
self._initial_state = type(self).default_initial_conditions()
self.update_inital_state(initial_state)

table = Table(title=f"Circulation model parameters ({type(self).__name__})")
table.add_column("Parameter")
Expand All @@ -108,12 +113,12 @@ def __init__(
table = Table(title=f"Circulation model initial states ({type(self).__name__})")
table.add_column("State")
table.add_column("Value")
recuursive_table(self.state, table)
recuursive_table(self._initial_state, table)
logger.info(f"\n{log.log_table(table)}")

if not add_units:
self.parameters = remove_units(self.parameters)
self.state = remove_units(self.state)
self.dy = np.zeros_like(self.state)

self.outdir = outdir
outdir.mkdir(exist_ok=True, parents=True)
Expand All @@ -137,17 +142,61 @@ def __init__(
log.setup_logging(level=loglevel)
self._comm = comm

def _initialize(self):
self.var = {}
self.results = defaultdict(list)
def update_static_variables_external(self, t: float, y: np.ndarray) -> None: ...

def update_inital_state(self, state: dict[str, float] | None = None):
if state is not None:
self._initial_state.update(state)

self.state = np.array(list(remove_units(self._initial_state).values()), dtype=np.float64)
self.state_old = np.copy(self.state)

@property
def states_names(self):
return list(self._initial_state.keys())

@property
def num_states(self):
return len(self.state)

@staticmethod
@abstractmethod
def var_names():
return []

@abstractmethod
def rhs(self, t: float, y: np.ndarray) -> np.ndarray: ...

@property
def num_vars(self):
return len(type(self).var_names())

self.update_static_variables(0.0)
@property
def state_theta(self):
return self._theta * self.state + (1 - self._theta) * self.state_old

def _initialize(self):
self.var = np.zeros(self.num_vars, dtype=np.float64)
if self._comm is None or (self._comm is not None and self._comm.rank == 0):
# Dump parameters to file
(self.outdir / "parameters.json").write_text(json.dumps(self.parameters, indent=2))
# Dump initial conditions to file
(self.outdir / "initial_conditions.json").write_text(json.dumps(self.state, indent=2))
(self.outdir / "initial_conditions.json").write_text(
json.dumps(remove_units(self._initial_state), indent=2)
)

def initialize_results(self, num_beats: int, dt_eval: float):
self.times = np.arange(0, num_beats / self.HR + dt_eval, dt_eval)
N = len(self.times)
self.results_state = np.zeros((self.num_states, N))
self.results_state[:, 0] = self.state

self.rhs(0.0, self.state)

self.results_var = np.zeros((self.num_vars, N))
self.results_var[:, 0] = self.var

self._index = 1

@property
@abstractmethod
Expand All @@ -157,17 +206,11 @@ def HR(self) -> float:

@staticmethod
@abstractmethod
def default_parameters() -> dict[str, Any]:
...

@abstractmethod
def update_static_variables(self, t: float):
pass
def default_parameters() -> dict[str, Any]: ...

@staticmethod
@abstractmethod
def default_initial_conditions() -> dict[str, float]:
...
def default_initial_conditions() -> dict[str, float]: ...

def time_varying_elastance(self, EA, EB, tC, TC, TR, **kwargs):
return time_varying_elastance.blanco_ventricle(
Expand Down Expand Up @@ -195,83 +238,88 @@ def _R(
* smooth_heavyside((v - w) / unit_p)
)

def times_one_beat(self, dt: float) -> np.ndarray:
return np.arange(0, 1 / self.HR, dt)

@abstractmethod
def step(self, t: float, dt: float) -> None:
...
def step(self, t: float, dt: float) -> None: ...

def solve(
self,
T: float | None = None,
num_cycles: int | None = None,
num_beats: int = 1,
initial_state: dict[str, float] | None = None,
dt: float = 1e-3,
dt_eval: float | None = None,
checkpoint: int = 0,
):
logger.info("Running circulation model")
if T is None:
assert num_cycles is not None, "Please provide num_cycles or T"
T = self.HR * num_cycles

initial_state = initial_state or dict()

if dt_eval is None:
output_every_n_steps = 1
else:
output_every_n_steps = np.round(dt_eval / dt)
dt_eval = dt

output_every_n_steps = np.round(dt_eval / dt)
self.initialize_results(num_beats, dt_eval)

if checkpoint > 0:
checkoint_every_n_steps = np.round(checkpoint / dt)
else:
checkoint_every_n_steps = np.inf

if initial_state is not None:
self.state.update(initial_state)
self.update_inital_state(initial_state)

t = 0.0
if self._add_units:
t *= units.ureg("s")
dt *= units.ureg("s")
else:
self.state = remove_units(self.state)

self.store(t)

time_start = time.time()

i = 0
while t < T:
self.callback(self, t, i % output_every_n_steps == 0)
self.step(t, dt)
if i % output_every_n_steps == 0:
self.store(t)
if self._verbose:
self.print_info()
for beat in range(num_beats):
logger.info(f"Solving beat {beat}")
for i, t in enumerate(self.times_one_beat(dt)):
self.callback(self, t, False)
self.step(t, dt)
if i % output_every_n_steps == 0:
self.store()
if self._verbose:
self.print_info()

if i % checkoint_every_n_steps == 0:
self.save_state()
t += dt
i += 1
if i % checkoint_every_n_steps == 0:
self.save_state()

duration = time.time() - time_start

logger.info(f"Done running circulation model in {duration:.2f} s")
self.save_state()
return self.results

def store(self, t):
get = lambda x: float(np.copy(x)) if not self._add_units else x.magnitude
return self.history

self.results["time"].append(get(t))
for k, v in self.state.items():
self.results[k].append(get(v))
for k, v in self.var.items():
self.results[k].append(get(v))
@property
def history(self):
history = {}
for i, name in enumerate(self.states_names):
history[name] = self.results_state[i, :]
for i, name in enumerate(type(self).var_names()):
history[name] = self.results_var[i, :]
history["time"] = self.times
return history

def store(self):
self.results_state[:, self._index] = self.state[:]
self.results_var[:, self._index] = self.var[:]
self._index += 1

def save_state(self):
self.callback_save_state(self)
(self.outdir / "state.json").write_text(json.dumps(self.state, indent=2))
(self.outdir / "results.json").write_text(json.dumps(self.results, indent=2))

np.savetxt(self.outdir / "state.txt", self.state)
np.savetxt(self.outdir / "results_state.txt", self.results_state)
np.savetxt(self.outdir / "results_var.txt", self.results_var)
np.savetxt(self.outdir / "time.txt", self.times)
np.savetxt(self.outdir / "var_names.txt", self.var_names(), fmt="%s")
np.savetxt(self.outdir / "state_names.txt", self.states_names, fmt="%s")

@property
def volumes(self) -> dict[str, float]:
Expand Down
Loading

0 comments on commit 900d5dd

Please sign in to comment.