Skip to content

Commit

Permalink
Add HMC sampling state
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Sep 25, 2024
1 parent b3097f4 commit 137ad63
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 20 deletions.
32 changes: 26 additions & 6 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@
from pymc.model import Point, modelcontext
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import step_sizes
from pymc.step_methods.arraystep import GradientSharedStep
from pymc.step_methods.compound import StepMethodState
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
from pymc.step_methods.hmc.quadpotential import (
PotentialState,
QuadPotentialDiagAdapt,
quad_potential,
)
from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState
from pymc.tuning import guess_scaling
from pymc.util import get_value_vars_from_user_vars
from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars

logger = logging.getLogger(__name__)

Expand All @@ -53,12 +57,26 @@ class HMCStepData(NamedTuple):
stats: dict[str, Any]


class BaseHMCState(StepMethodState):
adapt_step_size: bool
Emax: float
iter_count: int
step_size: np.ndarray
step_adapt: StepSizeState
target_accept: float
tune: bool
potential: PotentialState
_num_divs_sample: int


class BaseHMC(GradientSharedStep):
"""Superclass to implement Hamiltonian/hybrid monte carlo."""

integrator: integration.CpuLeapfrogIntegrator
default_blocked = True

_state_class = BaseHMCState

def __init__(
self,
vars=None,
Expand Down Expand Up @@ -126,9 +144,7 @@ def __init__(
size = sum(v.size for v in nuts_vars)

self.step_size = step_scale / (size**0.25)
self.step_adapt = step_sizes.DualAverageAdaptation(
self.step_size, target_accept, gamma, k, t0
)
self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0)
self.target_accept = target_accept
self.tune = True

Expand Down Expand Up @@ -260,3 +276,7 @@ def reset_tuning(self, start=None):
def reset(self, start=None):
self.tune = True
self.potential.reset()

def set_rng(self, rng: RandomGenerator):
self.rng = get_random_generator(rng, copy=False)
self.potential.set_rng(self.rng.spawn(1)[0])
8 changes: 7 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

from __future__ import annotations

from dataclasses import field
from typing import Any

import numpy as np

from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.vartypes import discrete_types

Expand All @@ -31,6 +32,11 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non
return (rng or np.random).uniform(elow, ehigh) * step_size


class HamiltonianMCState(BaseHMCState):
path_length: float = field(metadata={"frozen": True})
max_steps: int = field(metadata={"frozen": True})


class HamiltonianMC(BaseHMC):
R"""A sampler for continuous variables based on Hamiltonian mechanics.
Expand Down
8 changes: 7 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from collections import namedtuple
from dataclasses import field

import numpy as np

Expand All @@ -23,13 +24,18 @@
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.vartypes import continuous_types

__all__ = ["NUTS"]


class NUTSState(BaseHMCState):
max_treedepth: int = field(metadata={"frozen": True})
early_max_treedepth: int = field(metadata={"frozen": True})


class NUTS(BaseHMC):
r"""A sampler for continuous variables based on Hamiltonian mechanics.
Expand Down
116 changes: 105 additions & 11 deletions pymc/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import warnings

from typing import overload
from dataclasses import field
from typing import Any, overload

import numpy as np
import pytensor
Expand All @@ -25,6 +26,8 @@
from scipy.sparse import issparse

from pymc.pytensorf import floatX
from pymc.step_methods.state import DataClassState, WithSamplingState
from pymc.util import RandomGenerator, get_random_generator

__all__ = [
"quad_potential",
Expand Down Expand Up @@ -96,11 +99,17 @@ def __str__(self):
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."


class QuadPotential:
class PotentialState(DataClassState):
rng: np.random.Generator


class QuadPotential(WithSamplingState):
dtype: np.dtype

_state_class = PotentialState

def __init__(self, rng=None):
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

@overload
def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
Expand Down Expand Up @@ -153,15 +162,41 @@ def reset(self):
def stats(self):
return {"largest_eigval": np.nan, "smallest_eigval": np.nan}

def set_rng(self, rng: RandomGenerator):
self.rng = get_random_generator(rng, copy=False)


def isquadpotential(value):
"""Check whether an object might be a QuadPotential object."""
return isinstance(value, QuadPotential)


class QuadPotentialDiagAdaptState(PotentialState):
_var: np.ndarray
_stds: np.ndarray
_inv_stds: np.ndarray
_foreground_var: WeightedVarianceState
_background_var: WeightedVarianceState
_n_samples: int
adaptation_window: int
_mass_trace: list[np.ndarray] | None

dtype: Any = field(metadata={"frozen": True})
_n: int = field(metadata={"frozen": True})
_discard_window: int = field(metadata={"frozen": True})
_early_update: int = field(metadata={"frozen": True})
_initial_mean: np.ndarray = field(metadata={"frozen": True})
_initial_diag: np.ndarray = field(metadata={"frozen": True})
_initial_weight: np.ndarray = field(metadata={"frozen": True})
adaptation_window_multiplier: float = field(metadata={"frozen": True})
_store_mass_matrix_trace: bool = field(metadata={"frozen": True})


class QuadPotentialDiagAdapt(QuadPotential):
"""Adapt a diagonal mass matrix from the sample variances."""

_state_class = QuadPotentialDiagAdaptState

def __init__(
self,
n,
Expand Down Expand Up @@ -342,9 +377,19 @@ def raise_ok(self, map_info):
raise ValueError("\n".join(errmsg))


class _WeightedVariance:
class WeightedVarianceState(DataClassState):
n_samples: int
mean: np.ndarray
raw_var: np.ndarray

_dtype: Any = field(metadata={"frozen": True})


class _WeightedVariance(WithSamplingState):
"""Online algorithm for computing mean of variance."""

_state_class = WeightedVarianceState

def __init__(
self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d"
):
Expand Down Expand Up @@ -386,7 +431,15 @@ def current_mean(self):
return self.mean.copy(dtype=self._dtype)


class _ExpWeightedVariance:
class ExpWeightedVarianceState(DataClassState):
_alpha: float
_mean: np.ndarray
_var: np.ndarray


class _ExpWeightedVariance(WithSamplingState):
_state_class = ExpWeightedVarianceState

def __init__(self, n_vars, *, init_mean, init_var, alpha):
self._variance = init_var
self._mean = init_mean
Expand All @@ -411,7 +464,17 @@ def current_mean(self, out=None):
return out


class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
_alpha: float
_stop_adaptation: float
_variance_estimator: ExpWeightedVarianceState

_variance_estimator_grad: ExpWeightedVarianceState | None = None


class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
_state_class = QuadPotentialDiagAdaptExpState

def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs):
"""Set up a diagonal mass matrix.
Expand Down Expand Up @@ -512,7 +575,7 @@ def __init__(self, v, dtype=None, rng=None):
self.s = s
self.inv_s = 1.0 / s
self.v = v
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand Down Expand Up @@ -552,7 +615,7 @@ def __init__(self, A, dtype=None, rng=None):
dtype = pytensor.config.floatX
self.dtype = dtype
self.L = floatX(scipy.linalg.cholesky(A, lower=True))
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand Down Expand Up @@ -595,7 +658,7 @@ def __init__(self, cov, dtype=None, rng=None):
self._cov = np.array(cov, dtype=self.dtype, copy=True)
self._chol = scipy.linalg.cholesky(self._cov, lower=True)
self._n = len(self._cov)
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

def velocity(self, x, out=None):
"""Compute the current velocity at a position in parameter space."""
Expand All @@ -620,9 +683,30 @@ def velocity_energy(self, x, v_out):
__call__ = random


class QuadPotentialFullAdaptState(PotentialState):
_previous_update: int
_cov: np.ndarray
_chol: np.ndarray
_chol_error: scipy.linalg.LinAlgError | ValueError | None = None
_foreground_cov: WeightedCovarianceState
_background_cov: WeightedCovarianceState
_n_samples: int
adaptation_window: int

dtype: Any = field(metadata={"frozen": True})
_n: int = field(metadata={"frozen": True})
_update_window: int = field(metadata={"frozen": True})
_initial_mean: np.ndarray = field(metadata={"frozen": True})
_initial_cov: np.ndarray = field(metadata={"frozen": True})
_initial_weight: np.ndarray = field(metadata={"frozen": True})
adaptation_window_multiplier: float = field(metadata={"frozen": True})


class QuadPotentialFullAdapt(QuadPotentialFull):
"""Adapt a dense mass matrix using the sample covariances."""

_state_class = QuadPotentialFullAdaptState

def __init__(
self,
n,
Expand Down Expand Up @@ -663,7 +747,7 @@ def __init__(
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
self._update_window = int(update_window)

self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

self.reset()

Expand Down Expand Up @@ -716,7 +800,15 @@ def raise_ok(self, vmap):
raise ValueError(str(self._chol_error))


class _WeightedCovariance:
class WeightedCovarianceState(DataClassState):
n_samples: float
mean: np.ndarray
raw_cov: np.ndarray

_dtype: Any = field(metadata={"frozen": True})


class _WeightedCovariance(WithSamplingState):
"""Online algorithm for computing mean and covariance
This implements the `Welford's algorithm
Expand All @@ -726,6 +818,8 @@ class _WeightedCovariance:
"""

_state_class = WeightedCovarianceState

def __init__(
self,
nelem,
Expand Down Expand Up @@ -797,7 +891,7 @@ def __init__(self, A, rng=None):
self.size = A.shape[0]
self.factor = factor = cholmod.cholesky(A)
self.d_sqrt = np.sqrt(factor.D())
self.rng = np.random.default_rng(rng)
self.rng = get_random_generator(rng)

Check warning on line 894 in pymc/step_methods/hmc/quadpotential.py

View check run for this annotation

Codecov / codecov/patch

pymc/step_methods/hmc/quadpotential.py#L894

Added line #L894 was not covered by tests

def velocity(self, x):
"""Compute the current velocity at a position in parameter space."""
Expand Down
20 changes: 19 additions & 1 deletion pymc/step_methods/step_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from scipy import stats

from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods.state import DataClassState, WithSamplingState


class StepSizeState(DataClassState):
_log_step: np.ndarray
_log_bar: np.ndarray
_hbar: float
_count: int
_mu: np.ndarray
_tuned_stats: list
_initial_step: np.ndarray
_target: float
_k: float
_t0: float
_gamma: float


class DualAverageAdaptation(WithSamplingState):
_state_class = StepSizeState

class DualAverageAdaptation:
def __init__(self, initial_step, target, gamma, k, t0):
self._initial_step = initial_step
self._target = target
Expand Down

0 comments on commit 137ad63

Please sign in to comment.