Skip to content

Commit

Permalink
tests: increase coverage, remove redundant ValueError, sampler.chains…
Browse files Browse the repository at this point in the history
… now optional arg with default=1
  • Loading branch information
BradyPlanden committed Aug 14, 2024
1 parent e50812a commit 7a000cf
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 36 deletions.
17 changes: 7 additions & 10 deletions pybop/samplers/base_pints_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class BasePintsSampler(BaseSampler):
def __init__(
self,
log_pdf: LogPosterior,
chains: int,
sampler,
chains: int = 1,
warm_up=None,
x0=None,
cov0=0.1,
Expand Down Expand Up @@ -95,15 +95,12 @@ def __init__(
self._single_chain = issubclass(sampler, SingleChainMCMC)

# Construct the samplers object
try:
if self._single_chain:
self._n_samplers = self._n_chains
self._samplers = [sampler(x0, sigma0=self._cov0) for x0 in self._x0]
else:
self._n_samplers = 1
self._samplers = [sampler(self._n_chains, self._x0, self._cov0)]
except Exception as e:
raise ValueError(f"Error constructing samplers: {e}") from e
if self._single_chain:
self._n_samplers = self._n_chains
self._samplers = [sampler(x0, sigma0=self._cov0) for x0 in self._x0]
else:
self._n_samplers = 1
self._samplers = [sampler(self._n_chains, self._x0, self._cov0)]

# Check for sensitivities from sampler and set evaluation
self._needs_sensitivities = self._samplers[0].needs_sensitivities()
Expand Down
6 changes: 0 additions & 6 deletions pybop/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ def __init__(self, log_pdf: LogPosterior, x0, cov0: Union[np.ndarray, float]):
else np.asarray(x0, dtype=float)
)

# Validate x0 shape
if self._x0.ndim == 0:
raise ValueError(
f"x0 must be an array-like structure, but got a scalar: {x0}"
)

def run(self) -> np.ndarray:
"""
Sample from the posterior distribution.
Expand Down
38 changes: 20 additions & 18 deletions pybop/samplers/pints_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class NUTS(BasePintsSampler):
"""

def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(log_pdf, chains, NoUTurnMCMC, x0=x0, cov0=cov0, **kwargs)
super().__init__(
log_pdf, NoUTurnMCMC, chains=chains, x0=x0, cov0=cov0, **kwargs
)


class DREAM(BasePintsSampler):
Expand Down Expand Up @@ -73,7 +75,7 @@ class DREAM(BasePintsSampler):
"""

def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(log_pdf, chains, PintsDREAM, x0=x0, cov0=cov0, **kwargs)
super().__init__(log_pdf, PintsDREAM, chains=chains, x0=x0, cov0=cov0, **kwargs)


class AdaptiveCovarianceMCMC(BasePintsSampler):
Expand Down Expand Up @@ -101,8 +103,8 @@ class AdaptiveCovarianceMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsAdaptiveCovarianceMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -134,8 +136,8 @@ class DifferentialEvolutionMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsDifferentialEvolutionMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -168,8 +170,8 @@ class DramACMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsDramACMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -201,8 +203,8 @@ class EmceeHammerMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsEmceeHammerMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -234,8 +236,8 @@ class HaarioACMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsHaarioACMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -267,8 +269,8 @@ class HaarioBardenetACMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsHaarioBardenetACMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -300,8 +302,8 @@ class HamiltonianMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsHamiltonianMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -333,8 +335,8 @@ class MALAMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsMALAMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -366,8 +368,8 @@ class MetropolisRandomWalkMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsMetropolisRandomWalkMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -399,8 +401,8 @@ class MonomialGammaHamiltonianMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsMonomialGammaHamiltonianMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -432,8 +434,8 @@ class PopulationMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsPopulationMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -465,8 +467,8 @@ class RaoBlackwellACMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsRaoBlackwellACMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -498,8 +500,8 @@ class RelativisticMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsRelativisticMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -531,8 +533,8 @@ class SliceDoublingMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsSliceDoublingMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -564,8 +566,8 @@ class SliceRankShrinkingMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsSliceRankShrinkingMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down Expand Up @@ -597,8 +599,8 @@ class SliceStepoutMCMC(BasePintsSampler):
def __init__(self, log_pdf, chains, x0=None, cov0=None, **kwargs):
super().__init__(
log_pdf,
chains,
PintsSliceStepoutMCMC,
chains=chains,
x0=x0,
cov0=cov0,
**kwargs,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def multi_samplers(self):
],
)
@pytest.mark.unit
def test_initialization_and_run(
def test_initialisation_and_run(
self, log_posterior, x0, chains, MCMC, multi_samplers
):
sampler = pybop.MCMCSampler(
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_multi_log_pdf(self, log_posterior, x0, chains):
)

@pytest.mark.unit
def test_invalid_initialization(self, log_posterior, x0):
def test_invalid_initialisation(self, log_posterior, x0):
with pytest.raises(ValueError, match="Number of chains must be greater than 0"):
AdaptiveCovarianceMCMC(
log_pdf=log_posterior,
Expand Down

0 comments on commit 7a000cf

Please sign in to comment.