From 14e4223f22ef5eb48b361e6e1061b98dca03c196 Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Mon, 5 Aug 2024 17:51:29 +0100 Subject: [PATCH] Add parameters.reset_initial_value --- pybop/optimisers/base_optimiser.py | 2 +- pybop/parameters/parameter.py | 20 ++++++++++++++++---- pybop/problems/base_problem.py | 2 ++ pybop/problems/design_problem.py | 1 - pybop/problems/fitting_problem.py | 1 - tests/unit/test_parameters.py | 2 +- 6 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pybop/optimisers/base_optimiser.py b/pybop/optimisers/base_optimiser.py index c6d68e81..3b491974 100644 --- a/pybop/optimisers/base_optimiser.py +++ b/pybop/optimisers/base_optimiser.py @@ -129,7 +129,7 @@ def set_base_options(self): """ # Set initial values, if x0 is None, initial values are unmodified. self.parameters.update(initial_values=self.unset_options.pop("x0", None)) - self.x0 = self.parameters.initial_value() + self.x0 = self.parameters.reset_initial_value() # Set default bounds (for all or no parameters) self.bounds = self.unset_options.pop("bounds", self.parameters.get_bounds()) diff --git a/pybop/parameters/parameter.py b/pybop/parameters/parameter.py index 026c9f66..5631dfbc 100644 --- a/pybop/parameters/parameter.py +++ b/pybop/parameters/parameter.py @@ -194,13 +194,10 @@ def get_initial_value(self) -> float: self.update(initial_value=sample) else: warnings.warn( - "Initial value or Prior are None, proceeding without initial value.", + "Initial value and prior are None, proceeding without an initial value.", UserWarning, stacklevel=2, ) - else: - # Make sure to always reset the current value as well - self.update(value=self.initial_value) return self.initial_value @@ -424,6 +421,21 @@ def initial_value(self) -> np.ndarray: return np.asarray(initial_values) + def reset_initial_value(self) -> np.ndarray: + """ + Reset and return the initial value of each parameter. + """ + initial_values = [] + + for param in self.param.values(): + initial_value = param.get_initial_value() + if initial_value is not None: + # Reset the current value as well + param.update(value=initial_value) + initial_values.append(initial_value) + + return np.asarray(initial_values) + def current_value(self) -> np.ndarray: """ Return the current value of each parameter. diff --git a/pybop/problems/base_problem.py b/pybop/problems/base_problem.py index 2eab6c1b..7111224f 100644 --- a/pybop/problems/base_problem.py +++ b/pybop/problems/base_problem.py @@ -50,6 +50,8 @@ def __init__( ) self.parameters = parameters + self.parameters.reset_initial_value() + self._model = model self.check_model = check_model if isinstance(signal, str): diff --git a/pybop/problems/design_problem.py b/pybop/problems/design_problem.py index 6e3f5a76..3f94dce8 100644 --- a/pybop/problems/design_problem.py +++ b/pybop/problems/design_problem.py @@ -56,7 +56,6 @@ def __init__( super().__init__(parameters, model, check_model, signal, additional_variables) self.experiment = experiment self.init_soc = init_soc - self.parameters.initial_value() # Add an example dataset for plotting comparison sol = self.evaluate(self.parameters.as_dict("initial")) diff --git a/pybop/problems/fitting_problem.py b/pybop/problems/fitting_problem.py index e6f5cf22..b165c58b 100644 --- a/pybop/problems/fitting_problem.py +++ b/pybop/problems/fitting_problem.py @@ -45,7 +45,6 @@ def __init__( super().__init__(parameters, model, check_model, signal, additional_variables) self._dataset = dataset.data - self.parameters.initial_value() self._n_parameters = len(self.parameters) self._init_ocv = None if init_ocv is not None: diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index e48cda8e..1a3806aa 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -224,7 +224,7 @@ def test_initial_values_without_attributes(self): ) with pytest.warns( UserWarning, - match="Initial value or Prior are None, proceeding without initial value.", + match="Initial value and prior are None, proceeding without an initial value.", ): sample = parameter.initial_value()