feat: clean up attributes, implemented as properties for LogPosterior.
BradyPlanden committed Sep 18, 2024
1 parent 976872a commit f94dcc2
Showing 3 changed files with 64 additions and 34 deletions.
38 changes: 24 additions & 14 deletions pybop/costs/_likelihoods.py
@@ -232,18 +232,12 @@ def __init__(
         log_prior: Optional[Union[pybop.BasePrior, stats.rv_continuous]] = None,
         gradient_step: float = 1e-3,
     ):
-        super().__init__(problem=log_likelihood.problem)
-        self.gradient_step = gradient_step
-
-        # Store the likelihood, prior, update parameters and transformation
-        self.join_parameters(log_likelihood.parameters)
         self._log_likelihood = log_likelihood
-
-        for attr in ["transformation", "_has_separable_problem"]:
-            setattr(self, attr, getattr(log_likelihood, attr))
+        self.gradient_step = gradient_step
+        super().__init__(problem=log_likelihood.problem)
 
         if log_prior is None:
-            self._prior = JointLogPrior(*self._parameters.priors())
+            self._prior = JointLogPrior(*self.parameters.priors())
         else:
             self._prior = log_prior
@@ -275,16 +269,16 @@ def compute(
 
         if calculate_grad:
             if isinstance(self._prior, BasePrior):
-                log_prior, dp = self._prior.logpdfS1(self._parameters.current_value())
+                log_prior, dp = self._prior.logpdfS1(self.parameters.current_value())
             else:
                 # Compute log prior first
-                log_prior = self._prior.logpdf(self._parameters.current_value())
+                log_prior = self._prior.logpdf(self.parameters.current_value())
 
                 # Compute a finite difference approximation of the gradient of the log prior
-                delta = self._parameters.initial_value() * self.gradient_step
+                delta = self.parameters.initial_value() * self.gradient_step
                 dp = []
 
-                for parameter, step_size in zip(self._parameters, delta):
+                for parameter, step_size in zip(self.parameters, delta):
                     param_value = parameter.value
                     upper_value = param_value * (1 + step_size)
                     lower_value = param_value * (1 - step_size)
@@ -297,7 +291,7 @@
                     )
                     dp.append(gradient)
         else:
-            log_prior = self._prior.logpdf(self._parameters.current_value())
+            log_prior = self._prior.logpdf(self.parameters.current_value())
 
         if not np.isfinite(log_prior).any():
             return (-np.inf, -self.grad_fail) if calculate_grad else -np.inf
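When the prior does not expose an analytic first-order sensitivity (`logpdfS1`), the branch above falls back to a central finite difference with a relative step controlled by `gradient_step`. The quotient's denominator sits in lines collapsed out of this view, so the following is a hedged sketch of the scheme rather than pybop's exact expression:

```python
import numpy as np
from scipy import stats

def log_prior_gradient_fd(logpdf, values, gradient_step=1e-3):
    """Central-difference approximation to d/dtheta log p(theta)."""
    grads = []
    for value in values:
        upper = value * (1 + gradient_step)  # relative perturbation, as above
        lower = value * (1 - gradient_step)
        grads.append((logpdf(upper) - logpdf(lower)) / (upper - lower))
    return np.array(grads)

# Gradient of a normal log prior; the analytic answer is -(x - mu) / sigma**2
prior = stats.norm(loc=0.5, scale=0.1)
print(log_prior_gradient_fd(prior.logpdf, np.array([0.55, 0.60])))  # ~[-5., -10.]
```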
Expand All @@ -316,6 +310,22 @@ def compute(
posterior = log_likelihood + log_prior
return posterior

@property
def transformation(self):
return self._log_likelihood.transformation

@property
def has_separable_problem(self):
return self._log_likelihood.has_separable_problem

@property
def parameters(self):
return self._log_likelihood.parameters

@property
def n_parameters(self):
return self._log_likelihood.n_parameters

@property
def prior(self) -> BasePrior:
return self._prior
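Taken together, these hunks replace attribute copying with delegation: instead of the old `setattr` loop that snapshotted `transformation` and `_has_separable_problem` at construction time, `LogPosterior` now reads shared state through properties on demand, so the wrapper and its likelihood cannot drift apart. A minimal, self-contained sketch of the pattern (simplified stand-ins, not pybop's actual classes):

```python
class Likelihood:
    """Stand-in for pybop's BaseLikelihood; holds the shared state."""
    def __init__(self, parameters):
        self.parameters = parameters

class Posterior:
    """Stand-in for LogPosterior: shared attributes are read through,
    never copied at construction time."""
    def __init__(self, likelihood):
        self._likelihood = likelihood

    @property
    def parameters(self):
        return self._likelihood.parameters

    @property
    def n_parameters(self):
        return len(self._likelihood.parameters)

likelihood = Likelihood(parameters=["p1", "p2"])
posterior = Posterior(likelihood)
likelihood.parameters = ["p1"]       # mutate the wrapped object...
assert posterior.n_parameters == 1   # ...and the posterior reflects it
```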
50 changes: 34 additions & 16 deletions pybop/costs/base_cost.py
@@ -35,7 +35,7 @@ class BaseCost:
 
     def __init__(self, problem: Optional[BaseProblem] = None):
         self._parameters = Parameters()
-        self.transformation = None
+        self._transformation = None
         self.problem = problem
         self.verbose = False
         self._has_separable_problem = False
@@ -47,23 +47,11 @@ def __init__(self, problem: Optional[BaseProblem] = None):
             self._parameters.join(self.problem.parameters)
             self.n_outputs = self.problem.n_outputs
             self.signal = self.problem.signal
-            self.transformation = self._parameters.construct_transformation()
+            self._transformation = self._parameters.construct_transformation()
             self._has_separable_problem = True
         self.grad_fail = None
         self.set_fail_gradient()
 
-    @property
-    def n_parameters(self):
-        return len(self._parameters)
-
-    @property
-    def has_separable_problem(self):
-        return self._has_separable_problem
-
-    @property
-    def target(self):
-        return self._target
-
     def __call__(
         self,
         inputs: Union[Inputs, list],
@@ -92,11 +80,13 @@ def __call__(
         If an error occurs during the calculation of the cost.
         """
         # Apply transformation if needed
+        # Note, we use the transformation and parameter properties here to enable
+        # differing attributes within the `LogPosterior` class
         self.has_transform = self.transformation is not None and apply_transform
         if self.has_transform:
             inputs = self.transformation.to_model(inputs)
-        inputs = self._parameters.verify(inputs)
-        self._parameters.update(values=list(inputs.values()))
+        inputs = self.parameters.verify(inputs)
+        self.parameters.update(values=list(inputs.values()))
 
         y, dy = None, None
         if self._has_separable_problem:
@@ -193,6 +183,34 @@ def join_parameters(self, parameters):
         if original_n_params != self.n_parameters:
             self.set_fail_gradient()
 
+    @property
+    def n_parameters(self):
+        return len(self._parameters)
+
+    @property
+    def has_separable_problem(self):
+        return self._has_separable_problem
+
+    @has_separable_problem.setter
+    def has_separable_problem(self, has_separable_problem):
+        self._has_separable_problem = has_separable_problem
[Codecov / codecov/patch: added line pybop/costs/base_cost.py#L196 was not covered by tests.]

+    @property
+    def target(self):
+        return self._target
+
+    @property
+    def parameters(self):
+        return self._parameters
+
+    @parameters.setter
+    def parameters(self, parameters):
+        self._parameters = parameters
+
+    @property
+    def transformation(self):
+        return self._transformation
+
+    @transformation.setter
+    def transformation(self, transformation):
+        self._transformation = transformation
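The comment added in `__call__` is the crux of the commit: because `BaseCost` reaches its state through `self.parameters` and `self.transformation` rather than the private backing fields, a subclass can redirect those lookups without the base class changing. A hedged, self-contained sketch of the mechanism (hypothetical class names, not pybop's API):

```python
class BaseCost:
    """Stand-in for pybop's BaseCost: backing field plus property access."""
    def __init__(self):
        self._parameters = ["a", "b"]

    @property
    def parameters(self):
        return self._parameters

    @parameters.setter
    def parameters(self, parameters):
        self._parameters = parameters

    def __call__(self, inputs):
        # Reading self.parameters (not self._parameters) lets subclasses
        # that override the property redirect the lookup.
        return f"evaluating {inputs} against {self.parameters}"

class DelegatingCost(BaseCost):
    """Stand-in for LogPosterior: overrides the property to delegate."""
    def __init__(self, wrapped):
        super().__init__()
        self._wrapped = wrapped

    @property
    def parameters(self):
        return self._wrapped.parameters

inner = BaseCost()
inner.parameters = ["x"]
print(DelegatingCost(inner)(inputs={"x": 1.0}))  # uses the wrapped parameters
```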
10 changes: 6 additions & 4 deletions tests/unit/test_sampling.py
@@ -214,12 +214,14 @@ def test_multi_log_pdf(self, log_posterior, x0, chains):
         )
 
         # Test incorrect number of parameters
-        new_multi_log_posterior = copy.copy(log_posterior)
-        new_multi_log_posterior._parameters = [
-            new_multi_log_posterior._parameters[
+        likelihood_copy = copy.copy(log_posterior.likelihood)
+        likelihood_copy.parameters = pybop.Parameters(
+            likelihood_copy.parameters[
                 "Positive electrode active material volume fraction"
             ]
-        ]
+        )
+        new_multi_log_posterior = pybop.LogPosterior(likelihood_copy)
 
         with pytest.raises(
             ValueError, match="All log pdf's must have the same number of parameters"
         ):
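The rewritten test respects the new encapsulation: `LogPosterior.parameters` now delegates without defining a setter, which shadows the setter inherited from `BaseCost`, so assigning to it raises `AttributeError`. The test therefore mutates a copy of the likelihood and rebuilds the posterior. A small illustration of that property-shadowing behaviour (hypothetical names):

```python
class Base:
    @property
    def parameters(self):
        return self._parameters

    @parameters.setter
    def parameters(self, value):
        self._parameters = value

class ReadOnlyDelegate(Base):
    @property              # redefining the property without a setter
    def parameters(self):  # shadows Base's setter entirely
        return ["delegated"]

obj = ReadOnlyDelegate()
try:
    obj.parameters = ["p1"]   # no setter on the overriding property
except AttributeError as err:
    print(f"rejected: {err}")
```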
