From cbff886f5cbae5ddb1cf94d6442ea641cfcfebbb Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Thu, 3 Oct 2024 15:37:33 +0200 Subject: [PATCH] Make it possible to pass initial state into initializer --- src/circulation/base.py | 31 ++++++++++++++++++++----------- src/circulation/regazzoni2020.py | 2 ++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/circulation/base.py b/src/circulation/base.py index f3ed924..5693bc5 100644 --- a/src/circulation/base.py +++ b/src/circulation/base.py @@ -86,21 +86,34 @@ def __init__( verbose: bool = False, comm=None, callback_save_state: CallBack | None = None, + initial_state: dict[str, float] | None = None, ): self.parameters = type(self).default_parameters() if parameters is not None: self.parameters = deep_update(self.parameters, parameters) + self._add_units = add_units + + self.state = type(self).default_initial_conditions() + if initial_state is not None: + self.state.update(initial_state) + table = Table(title=f"Circulation model parameters ({type(self).__name__})") table.add_column("Parameter") table.add_column("Value") recuursive_table(self.parameters, table) logger.info(f"\n{log.log_table(table)}") + table = Table(title=f"Circulation model initial states ({type(self).__name__})") + table.add_column("State") + table.add_column("Value") + recuursive_table(self.state, table) + logger.info(f"\n{log.log_table(table)}") + if not add_units: self.parameters = remove_units(self.parameters) + self.state = remove_units(self.state) - self._add_units = add_units self.outdir = outdir outdir.mkdir(exist_ok=True, parents=True) @@ -126,8 +139,7 @@ def __init__( def _initialize(self): self.var = {} self.results = defaultdict(list) - self.state = type(self).default_initial_conditions() - self.update_state() + self.update_static_variables(0.0) if self._comm is None or (self._comm is not None and self._comm.rank == 0): @@ -150,13 +162,6 @@ def default_parameters() -> dict[str, Any]: ... def update_static_variables(self, t: float): pass - def update_state(self, state: dict[str, float] | None = None): - if state is not None: - self.state.update(state) - - if not self._add_units: - self.state = remove_units(self.state) - @staticmethod @abstractmethod def default_initial_conditions() -> dict[str, float]: ... @@ -216,11 +221,15 @@ def solve( else: checkoint_every_n_steps = np.inf - self.update_state(state=initial_state) + if initial_state is not None: + self.state.update(initial_state) + t = 0.0 if self._add_units: t *= units.ureg("s") dt *= units.ureg("s") + else: + self.state = remove_units(self.state) self.store(t) diff --git a/src/circulation/regazzoni2020.py b/src/circulation/regazzoni2020.py index b03c78f..6f32433 100644 --- a/src/circulation/regazzoni2020.py +++ b/src/circulation/regazzoni2020.py @@ -31,6 +31,7 @@ def __init__( verbose: bool = False, comm=None, outdir: Path = Path("results-regazzoni"), + initial_state: dict[str, float] | None = None, ): super().__init__( parameters, @@ -39,6 +40,7 @@ def __init__( verbose=verbose, comm=comm, outdir=outdir, + initial_state=initial_state, ) chambers = self.parameters["chambers"] valves = self.parameters["valves"]