Skip to content

Commit

Permalink
Add parameters.reset_initial_value
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaCourtier committed Aug 5, 2024
1 parent cd87e83 commit 14e4223
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pybop/optimisers/base_optimiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def set_base_options(self):
"""
# Set initial values, if x0 is None, initial values are unmodified.
self.parameters.update(initial_values=self.unset_options.pop("x0", None))
self.x0 = self.parameters.initial_value()
self.x0 = self.parameters.reset_initial_value()

# Set default bounds (for all or no parameters)
self.bounds = self.unset_options.pop("bounds", self.parameters.get_bounds())
Expand Down
20 changes: 16 additions & 4 deletions pybop/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,10 @@ def get_initial_value(self) -> float:
self.update(initial_value=sample)
else:
warnings.warn(
"Initial value or Prior are None, proceeding without initial value.",
"Initial value and prior are None, proceeding without an initial value.",
UserWarning,
stacklevel=2,
)
else:
# Make sure to always reset the current value as well
self.update(value=self.initial_value)

return self.initial_value

Expand Down Expand Up @@ -424,6 +421,21 @@ def initial_value(self) -> np.ndarray:

return np.asarray(initial_values)

def reset_initial_value(self) -> np.ndarray:
"""
Reset and return the initial value of each parameter.
"""
initial_values = []

for param in self.param.values():
initial_value = param.get_initial_value()
if initial_value is not None:
# Reset the current value as well
param.update(value=initial_value)
initial_values.append(initial_value)

return np.asarray(initial_values)

def current_value(self) -> np.ndarray:
"""
Return the current value of each parameter.
Expand Down
2 changes: 2 additions & 0 deletions pybop/problems/base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(
)

self.parameters = parameters
self.parameters.reset_initial_value()

self._model = model
self.check_model = check_model
if isinstance(signal, str):
Expand Down
1 change: 0 additions & 1 deletion pybop/problems/design_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(
super().__init__(parameters, model, check_model, signal, additional_variables)
self.experiment = experiment
self.init_soc = init_soc
self.parameters.initial_value()

# Add an example dataset for plotting comparison
sol = self.evaluate(self.parameters.as_dict("initial"))
Expand Down
1 change: 0 additions & 1 deletion pybop/problems/fitting_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(

super().__init__(parameters, model, check_model, signal, additional_variables)
self._dataset = dataset.data
self.parameters.initial_value()
self._n_parameters = len(self.parameters)
self._init_ocv = None
if init_ocv is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_initial_values_without_attributes(self):
)
with pytest.warns(
UserWarning,
match="Initial value or Prior are None, proceeding without initial value.",
match="Initial value and prior are None, proceeding without an initial value.",
):
sample = parameter.initial_value()

Expand Down

0 comments on commit 14e4223

Please sign in to comment.