Skip to content

Commit

Permalink
Merge pull request #315 from pybop-team/model-rebuild-parameters-refa…
Browse files Browse the repository at this point in the history
…ctor

Renames `model.rebuild` variables, updates import structure, adds type-hints to BaseModel
  • Loading branch information
BradyPlanden authored Jun 6, 2024
2 parents 02cbb14 + 1020428 commit f32c383
Show file tree
Hide file tree
Showing 23 changed files with 234 additions and 188 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- [#315](https://github.com/pybop-team/PyBOP/pull/315) - Updates __init__ structure to remove circular import issues and minimises dependancy imports across codebase for faster PyBOP module import. Adds type-hints to BaseModel and refactors rebuild parameter variables.
- [#236](https://github.com/pybop-team/PyBOP/issues/236) - Restructures the optimiser classes, adds a new optimisation API through direct construction and keyword arguments, and fixes the setting of `max_iterations`, and `_minimising`. Introduces `pybop.BaseOptimiser`, `pybop.BasePintsOptimiser`, and `pybop.BaseSciPyOptimiser` classes.
- [#321](https://github.com/pybop-team/PyBOP/pull/321) - Updates Prior classes with BaseClass, adds a `problem.sample_initial_conditions` method to improve stability of SciPy.Minimize optimiser.
- [#249](https://github.com/pybop-team/PyBOP/pull/249) - Add WeppnerHuggins model and GITT example.
Expand Down
2 changes: 1 addition & 1 deletion examples/standalone/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def __init__(
self._built_initial_soc = None
self._mesh = None
self._disc = None
self.rebuild_parameters = {}
self.geometric_parameters = {}
55 changes: 27 additions & 28 deletions pybop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,33 @@
from ._utils import is_numeric

#
# Problem classes
# Experiment class
#
from ._experiment import Experiment

#
# Dataset class
#
from ._dataset import Dataset

#
# Parameter classes
#
from .parameters.parameter import Parameter, Parameters
from .parameters.parameter_set import ParameterSet
from .parameters.priors import BasePrior, Gaussian, Uniform, Exponential

#
# Model classes
#
from .models.base_model import BaseModel
from .models import lithium_ion
from .models import empirical
from .models.base_model import TimeSeriesState
from .models.base_model import Inputs

#
# Problem class
#
from .problems.base_problem import BaseProblem
from .problems.fitting_problem import FittingProblem
Expand All @@ -73,25 +99,6 @@
GaussianLogLikelihoodKnownSigma,
)

#
# Dataset class
#
from ._dataset import Dataset

#
# Model classes
#
from .models.base_model import BaseModel
from .models import lithium_ion
from .models import empirical
from .models.base_model import TimeSeriesState
from .models.base_model import Inputs

#
# Experiment class
#
from ._experiment import Experiment

#
# Optimiser class
#
Expand All @@ -114,14 +121,6 @@
)
from .optimisers.optimisation import Optimisation

#
# Parameter classes
#
from .parameters.parameter import Parameter, Parameters
from .parameters.parameter_set import ParameterSet
from .parameters.priors import BasePrior, Gaussian, Uniform, Exponential


#
# Observer classes
#
Expand Down
7 changes: 4 additions & 3 deletions pybop/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pybamm
from pybamm import Interpolant, solvers
from pybamm import t as pybamm_t


class Dataset:
Expand All @@ -20,7 +21,7 @@ def __init__(self, data_dictionary):
Initialize a Dataset instance with data and a set of names.
"""

if isinstance(data_dictionary, pybamm.solvers.solution.Solution):
if isinstance(data_dictionary, solvers.solution.Solution):
data_dictionary = data_dictionary.get_data_dict()
if not isinstance(data_dictionary, dict):
raise TypeError("The input to pybop.Dataset must be a dictionary.")
Expand Down Expand Up @@ -91,7 +92,7 @@ def Interpolant(self):
"""

if self.variable == "time":
self.Interpolant = pybamm.Interpolant(self.x, self.y, pybamm.t)
self.Interpolant = Interpolant(self.x, self.y, pybamm_t)
else:
NotImplementedError("Only time interpolation is supported")

Expand Down
103 changes: 56 additions & 47 deletions pybop/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import casadi
import numpy as np
import pybamm

from pybop import Dataset, Experiment, Parameters, ParameterSet

Inputs = Dict[str, float]


Expand Down Expand Up @@ -63,13 +65,12 @@ def __init__(self, name="Base Model", parameter_set=None):
else: # a pybop parameter set
self._parameter_set = pybamm.ParameterValues(parameter_set.params)

self.pybamm_model = None
self.parameters = None
self.dataset = None
self.signal = None
self.additional_variables = []
self.matched_parameters = {}
self.non_matched_parameters = {}
self.rebuild_parameters = {}
self.standard_parameters = {}
self.param_check_counter = 0
self.allow_infeasible_solutions = True

Expand All @@ -79,11 +80,11 @@ def n_parameters(self):

def build(
self,
dataset=None,
parameters=None,
check_model=True,
init_soc=None,
):
dataset: Dataset = None,
parameters: Union[Parameters, Dict] = None,
check_model: bool = True,
init_soc: float = None,
) -> None:
"""
Construct the PyBaMM model if not already built, and set parameters.
Expand All @@ -95,8 +96,8 @@ def build(
----------
dataset : pybamm.Dataset, optional
The dataset to be used in the model construction.
parameters : dict, optional
A dictionary containing parameter values to apply to the model.
parameters : pybop.Parameters or Dict, optional
A pybop Parameters class or dictionary containing parameter values to apply to the model.
check_model : bool, optional
If True, the model will be checked for correctness after construction.
init_soc : float, optional
Expand Down Expand Up @@ -133,7 +134,7 @@ def build(

self.n_states = self._built_model.len_rhs_and_alg # len_rhs + len_alg

def set_init_soc(self, init_soc):
def set_init_soc(self, init_soc: float):
"""
Set the initial state of charge for the battery model.
Expand Down Expand Up @@ -169,10 +170,10 @@ def set_params(self, rebuild=False):
return

# Mark any simulation inputs in the parameter set
for key in self.non_matched_parameters.keys():
for key in self.standard_parameters.keys():
self._parameter_set[key] = "[input]"

if self.dataset is not None and (not self.matched_parameters or not rebuild):
if self.dataset is not None and (not self.rebuild_parameters or not rebuild):
if (
self.parameters is None
or "Current function [A]" not in self.parameters.keys()
Expand All @@ -194,12 +195,12 @@ def set_params(self, rebuild=False):

def rebuild(
self,
dataset=None,
parameters=None,
parameter_set=None,
check_model=True,
init_soc=None,
):
dataset: Dataset = None,
parameters: Union[Parameters, Dict] = None,
parameter_set: ParameterSet = None,
check_model: bool = True,
init_soc: float = None,
) -> None:
"""
Rebuild the PyBaMM model for a given parameter set.
Expand All @@ -212,8 +213,8 @@ def rebuild(
----------
dataset : pybamm.Dataset, optional
The dataset to be used in the model construction.
parameters : dict, optional
A dictionary containing parameter values to apply to the model.
parameters : pybop.Parameters or Dict, optional
A pybop Parameters class or dictionary containing parameter values to apply to the model.
parameter_set : pybop.parameter_set, optional
A PyBOP parameter set object or a dictionary containing the parameter values
check_model : bool, optional
Expand Down Expand Up @@ -242,34 +243,35 @@ def rebuild(
# Clear solver and setup model
self._solver._model_set_up = {}

def classify_and_update_parameters(self, parameters):
def classify_and_update_parameters(self, parameters: Union[Parameters, Dict]):
"""
Update the parameter values according to their classification as either
'matched_parameters' which require a model rebuild and
'non_matched_parameters' which are standard inputs.
'rebuild_parameters' which require a model rebuild and
'standard_parameters' which do not.
Parameters
----------
parameters : pybop.ParameterSet
"""
parameter_dictionary = parameters.as_dict()
matched_parameters = {
rebuild_parameters = {
param: parameter_dictionary[param]
for param in parameter_dictionary
if param in self.rebuild_parameters
if param in self.geometric_parameters
}
non_matched_parameters = {
standard_parameters = {
param: parameter_dictionary[param]
for param in parameter_dictionary
if param not in self.rebuild_parameters
if param not in self.geometric_parameters
}

self.matched_parameters.update(matched_parameters)
self.non_matched_parameters.update(non_matched_parameters)
self.rebuild_parameters.update(rebuild_parameters)
self.standard_parameters.update(standard_parameters)

if self.matched_parameters:
self._parameter_set.update(self.matched_parameters)
# Update the parameter set and geometry for rebuild parameters
if self.rebuild_parameters:
self._parameter_set.update(self.rebuild_parameters)
self._unprocessed_parameter_set = self._parameter_set
self.geometry = self.pybamm_model.default_geometry

Expand Down Expand Up @@ -322,7 +324,9 @@ def step(self, state: TimeSeriesState, time: np.ndarray) -> TimeSeriesState:
)
return TimeSeriesState(sol=new_sol, inputs=state.inputs, t=time)

def simulate(self, inputs, t_eval) -> np.ndarray[np.float64]:
def simulate(
self, inputs: Inputs, t_eval: np.array
) -> Dict[str, np.ndarray[np.float64]]:
"""
Execute the forward model simulation and return the result.
Expand All @@ -347,7 +351,7 @@ def simulate(self, inputs, t_eval) -> np.ndarray[np.float64]:
if self._built_model is None:
raise ValueError("Model must be built before calling simulate")
else:
if self.matched_parameters and not self.non_matched_parameters:
if self.rebuild_parameters and not self.standard_parameters:
sol = self.solver.solve(self.built_model, t_eval=t_eval)

else:
Expand Down Expand Up @@ -375,7 +379,7 @@ def simulate(self, inputs, t_eval) -> np.ndarray[np.float64]:

return y

def simulateS1(self, inputs, t_eval):
def simulateS1(self, inputs: Inputs, t_eval: np.array):
"""
Perform the forward model simulation with sensitivities.
Expand All @@ -402,7 +406,7 @@ def simulateS1(self, inputs, t_eval):
if self._built_model is None:
raise ValueError("Model must be built before calling simulate")
else:
if self.matched_parameters:
if self.rebuild_parameters:
raise ValueError(
"Cannot use sensitivies for parameters which require a model rebuild"
)
Expand Down Expand Up @@ -451,12 +455,12 @@ def simulateS1(self, inputs, t_eval):

def predict(
self,
inputs=None,
t_eval=None,
parameter_set=None,
experiment=None,
init_soc=None,
):
inputs: Inputs = None,
t_eval: np.array = None,
parameter_set: ParameterSet = None,
experiment: Experiment = None,
init_soc: float = None,
) -> Dict[str, np.ndarray[np.float64]]:
"""
Solve the model using PyBaMM's simulation framework and return the solution.
Expand Down Expand Up @@ -530,7 +534,10 @@ def predict(
return [np.inf]

def check_params(
self, inputs=None, parameter_set=None, allow_infeasible_solutions=True
self,
inputs: Inputs = None,
parameter_set: ParameterSet = None,
allow_infeasible_solutions: bool = True,
):
"""
Check compatibility of the model parameters.
Expand Down Expand Up @@ -564,7 +571,9 @@ def check_params(
inputs=inputs, allow_infeasible_solutions=allow_infeasible_solutions
)

def _check_params(self, inputs=None, allow_infeasible_solutions=True):
def _check_params(
self, inputs: Inputs = None, allow_infeasible_solutions: bool = True
):
"""
A compatibility check for the model parameters which can be implemented by subclasses
if required, otherwise it returns True by default.
Expand Down Expand Up @@ -594,7 +603,7 @@ def copy(self):
"""
return copy.copy(self)

def cell_mass(self, parameter_set=None):
def cell_mass(self, parameter_set: ParameterSet = None):
"""
Calculate the cell mass in kilograms.
Expand All @@ -613,7 +622,7 @@ def cell_mass(self, parameter_set=None):
"""
raise NotImplementedError

def cell_volume(self, parameter_set=None):
def cell_volume(self, parameter_set: ParameterSet = None):
"""
Calculate the cell volume in m3.
Expand Down
Loading

0 comments on commit f32c383

Please sign in to comment.