Skip to content

Commit

Permalink
Add set_current_function and test
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaCourtier committed Jul 17, 2024
1 parent 5c370f5 commit 6c7c3c3
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 19 deletions.
41 changes: 25 additions & 16 deletions pybop/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self.parameter_set = pybamm.ParameterValues(parameter_set.params)

self.parameters = Parameters()
self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"]
self.output_variables = []
self.rebuild_parameters = {}
self.standard_parameters = {}
self.param_check_counter = 0
Expand Down Expand Up @@ -125,6 +125,9 @@ def build(
if init_soc is not None:
self.set_init_soc(init_soc)

if dataset is not None:
self.set_current_function(dataset)

if self._built_model:
return

Expand All @@ -135,7 +138,7 @@ def build(
else:
if not self.pybamm_model._built:
self.pybamm_model.build_model()
self.set_params(dataset=dataset)
self.set_params()

self._mesh = pybamm.Mesh(self.geometry, self.submesh_types, self.var_pts)
self._disc = pybamm.Discretisation(self.mesh, self.spatial_methods)
Expand Down Expand Up @@ -170,7 +173,21 @@ def set_init_soc(self, init_soc: float):
for key in self.standard_parameters.keys():
self._parameter_set[key] = "[input]"

def set_params(self, rebuild=False, dataset=None):
def set_current_function(self, dataset: Dataset):
"""
Construct the current function from the dataset.
"""
if (
self._parameter_set is not None
and "Current function [A]" in self._parameter_set.keys()
):
self._parameter_set["Current function [A]"] = pybamm.Interpolant(
dataset["Time [s]"],
dataset["Current function [A]"],
pybamm.t,
)

def set_params(self, rebuild=False):
"""
Assign the parameters to the model.
Expand All @@ -180,17 +197,6 @@ def set_params(self, rebuild=False, dataset=None):
if self.model_with_set_params and not rebuild:
return

if dataset is not None and (not self.rebuild_parameters or not rebuild):
if (
self.parameters is None
or "Current function [A]" not in self.parameters.keys()
):
self._parameter_set["Current function [A]"] = pybamm.Interpolant(
dataset["Time [s]"],
dataset["Current function [A]"],
pybamm.t,
)

self._model_with_set_params = self._parameter_set.process_model(
self._unprocessed_model, inplace=False
)
Expand Down Expand Up @@ -230,10 +236,13 @@ def rebuild(
if init_soc is not None:
self.set_init_soc(init_soc)

if dataset is not None:
self.set_current_function(dataset)

if self._built_model is None:
raise ValueError("Model must be built before calling rebuild")

self.set_params(rebuild=True, dataset=dataset)
self.set_params(rebuild=True)
self._mesh = pybamm.Mesh(self.geometry, self.submesh_types, self.var_pts)
self._disc = pybamm.Discretisation(self.mesh, self.spatial_methods)
self._built_model = self._disc.process_model(
Expand Down Expand Up @@ -614,7 +623,7 @@ def check_params(

return self._check_params(
inputs=inputs,
parameter_set=parameter_set,
parameter_set=parameter_set or self._parameter_set,
allow_infeasible_solutions=allow_infeasible_solutions,
)

Expand Down
1 change: 1 addition & 0 deletions pybop/models/empirical/base_ecm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
self._mesh = None
self._disc = None
self.geometric_parameters = {}
self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"]

def _check_params(
self,
Expand Down
1 change: 1 addition & 0 deletions pybop/models/lithium_ion/base_echem.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(

self._electrode_soh = pybamm_lithium_ion.electrode_soh
self.geometric_parameters = self.set_geometric_parameters()
self.output_variables = ["Time [s]", "Current [A]", "Voltage [V]"]

def _check_params(
self,
Expand Down
2 changes: 1 addition & 1 deletion pybop/observers/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _evaluate(self, inputs: Inputs):

output = {}
ys = []
if hasattr(self, "_dataset"):
if self._dataset is not None:
for signal in self.signal:
ym = self._target[signal]
for i, t in enumerate(self._time_data):
Expand Down
3 changes: 2 additions & 1 deletion pybop/observers/unscented_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import numpy as np
import scipy.linalg as linalg

from pybop.models.base_model import BaseModel, Dataset, Inputs
from pybop import Dataset
from pybop.models.base_model import BaseModel, Inputs
from pybop.observers.observer import Observer
from pybop.parameters.parameter import Parameter

Expand Down
1 change: 1 addition & 0 deletions pybop/problems/base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self.variables = list(set(self.variables))
self.init_soc = init_soc
self.n_outputs = len(self.signal)
self._dataset = None
self._time_data = None
self._target = None

Expand Down
2 changes: 1 addition & 1 deletion pybop/problems/design_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self,
model: BaseModel,
parameters: Parameters,
experiment: Experiment,
experiment: Optional[Experiment],
check_model: bool = True,
signal: Optional[list[str]] = None,
additional_variables: Optional[list[str]] = None,
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,26 @@ def test_set_init_soc(self):
np.testing.assert_allclose(
values_1["Voltage [V]"].data, values_2["Voltage [V]"].data, atol=1e-8
)

@pytest.mark.unit
def test_set_current_function(self):
t_eval = np.linspace(0, 10, 100)
dataset_1 = pybop.Dataset(
{"Time [s]": t_eval, "Current function [A]": np.ones(100)}
)
dataset_2 = pybop.Dataset(
{"Time [s]": t_eval, "Current function [A]": np.zeros(100)}
)

model = pybop.lithium_ion.SPM()
model.build(dataset=dataset_1)
values_1 = model.predict(t_eval=t_eval)

model = pybop.lithium_ion.SPM()
model.build(dataset=dataset_2)
model.set_current_function(dataset=dataset_1)
values_2 = model.predict(t_eval=t_eval)

np.testing.assert_allclose(
values_1["Voltage [V]"].data, values_2["Voltage [V]"].data, atol=1e-8
)

0 comments on commit 6c7c3c3

Please sign in to comment.