diff --git a/pybop/models/base_model.py b/pybop/models/base_model.py index d4b5fac5..b62b4dbe 100644 --- a/pybop/models/base_model.py +++ b/pybop/models/base_model.py @@ -87,7 +87,7 @@ def __init__( self.parameter_set = pybamm.ParameterValues(parameter_set.params) self.parameters = Parameters() - self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"] + self.output_variables = [] self.rebuild_parameters = {} self.standard_parameters = {} self.param_check_counter = 0 @@ -125,6 +125,9 @@ def build( if init_soc is not None: self.set_init_soc(init_soc) + if dataset is not None: + self.set_current_function(dataset) + if self._built_model: return @@ -135,7 +138,7 @@ def build( else: if not self.pybamm_model._built: self.pybamm_model.build_model() - self.set_params(dataset=dataset) + self.set_params() self._mesh = pybamm.Mesh(self.geometry, self.submesh_types, self.var_pts) self._disc = pybamm.Discretisation(self.mesh, self.spatial_methods) @@ -170,7 +173,21 @@ def set_init_soc(self, init_soc: float): for key in self.standard_parameters.keys(): self._parameter_set[key] = "[input]" - def set_params(self, rebuild=False, dataset=None): + def set_current_function(self, dataset: Dataset): + """ + Construct the current function from the dataset. + """ + if ( + self._parameter_set is not None + and "Current function [A]" in self._parameter_set.keys() + ): + self._parameter_set["Current function [A]"] = pybamm.Interpolant( + dataset["Time [s]"], + dataset["Current function [A]"], + pybamm.t, + ) + + def set_params(self, rebuild=False): """ Assign the parameters to the model. @@ -180,17 +197,6 @@ def set_params(self, rebuild=False, dataset=None): if self.model_with_set_params and not rebuild: return - if dataset is not None and (not self.rebuild_parameters or not rebuild): - if ( - self.parameters is None - or "Current function [A]" not in self.parameters.keys() - ): - self._parameter_set["Current function [A]"] = pybamm.Interpolant( - dataset["Time [s]"], - dataset["Current function [A]"], - pybamm.t, - ) - self._model_with_set_params = self._parameter_set.process_model( self._unprocessed_model, inplace=False ) @@ -230,10 +236,13 @@ def rebuild( if init_soc is not None: self.set_init_soc(init_soc) + if dataset is not None: + self.set_current_function(dataset) + if self._built_model is None: raise ValueError("Model must be built before calling rebuild") - self.set_params(rebuild=True, dataset=dataset) + self.set_params(rebuild=True) self._mesh = pybamm.Mesh(self.geometry, self.submesh_types, self.var_pts) self._disc = pybamm.Discretisation(self.mesh, self.spatial_methods) self._built_model = self._disc.process_model( @@ -614,7 +623,7 @@ def check_params( return self._check_params( inputs=inputs, - parameter_set=parameter_set, + parameter_set=parameter_set or self._parameter_set, allow_infeasible_solutions=allow_infeasible_solutions, ) diff --git a/pybop/models/empirical/base_ecm.py b/pybop/models/empirical/base_ecm.py index 509235f4..8d0c72d1 100644 --- a/pybop/models/empirical/base_ecm.py +++ b/pybop/models/empirical/base_ecm.py @@ -85,6 +85,7 @@ def __init__( self._mesh = None self._disc = None self.geometric_parameters = {} + self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"] def _check_params( self, diff --git a/pybop/models/lithium_ion/base_echem.py b/pybop/models/lithium_ion/base_echem.py index e424f04e..b28a363f 100644 --- a/pybop/models/lithium_ion/base_echem.py +++ b/pybop/models/lithium_ion/base_echem.py @@ -83,6 +83,7 @@ def __init__( self._electrode_soh = pybamm_lithium_ion.electrode_soh self.geometric_parameters = self.set_geometric_parameters() + self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"] def _check_params( self, diff --git a/pybop/observers/observer.py b/pybop/observers/observer.py index d1c85ce6..66b92250 100644 --- a/pybop/observers/observer.py +++ b/pybop/observers/observer.py @@ -155,7 +155,7 @@ def _evaluate(self, inputs: Inputs): output = {} ys = [] - if hasattr(self, "_dataset"): + if self._dataset is not None: for signal in self.signal: ym = self._target[signal] for i, t in enumerate(self._time_data): diff --git a/pybop/observers/unscented_kalman.py b/pybop/observers/unscented_kalman.py index 966849c3..df81b084 100644 --- a/pybop/observers/unscented_kalman.py +++ b/pybop/observers/unscented_kalman.py @@ -4,7 +4,8 @@ import numpy as np import scipy.linalg as linalg -from pybop.models.base_model import BaseModel, Dataset, Inputs +from pybop import Dataset +from pybop.models.base_model import BaseModel, Inputs from pybop.observers.observer import Observer from pybop.parameters.parameter import Parameter diff --git a/pybop/problems/base_problem.py b/pybop/problems/base_problem.py index 102046a7..7ed36e87 100644 --- a/pybop/problems/base_problem.py +++ b/pybop/problems/base_problem.py @@ -73,6 +73,7 @@ def __init__( self.variables = list(set(self.variables)) self.init_soc = init_soc self.n_outputs = len(self.signal) + self._dataset = None self._time_data = None self._target = None diff --git a/pybop/problems/design_problem.py b/pybop/problems/design_problem.py index 63cc7974..5fd86636 100644 --- a/pybop/problems/design_problem.py +++ b/pybop/problems/design_problem.py @@ -34,7 +34,7 @@ def __init__( self, model: BaseModel, parameters: Parameters, - experiment: Experiment, + experiment: Optional[Experiment], check_model: bool = True, signal: Optional[list[str]] = None, additional_variables: Optional[list[str]] = None, diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index cdf6ee8b..b5def40a 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -383,3 +383,26 @@ def test_set_init_soc(self): np.testing.assert_allclose( values_1["Voltage [V]"].data, values_2["Voltage [V]"].data, atol=1e-8 ) + + @pytest.mark.unit + def test_set_current_function(self): + t_eval = np.linspace(0, 10, 100) + dataset_1 = pybop.Dataset( + {"Time [s]": t_eval, "Current function [A]": np.ones(100)} + ) + dataset_2 = pybop.Dataset( + {"Time [s]": t_eval, "Current function [A]": np.zeros(100)} + ) + + model = pybop.lithium_ion.SPM() + model.build(dataset=dataset_1) + values_1 = model.predict(t_eval=t_eval) + + model = pybop.lithium_ion.SPM() + model.build(dataset=dataset_2) + model.set_current_function(dataset=dataset_1) + values_2 = model.predict(t_eval=t_eval) + + np.testing.assert_allclose( + values_1["Voltage [V]"].data, values_2["Voltage [V]"].data, atol=1e-8 + )