From f51546de8a39b8fc9b0d730c385996bdf97464ae Mon Sep 17 00:00:00 2001
From: J S <49557684+svilupp@users.noreply.github.com>
Date: Tue, 10 Jan 2023 20:32:56 +0000
Subject: [PATCH 1/3] initial implementation of nutpie

---
 Project.toml                     |   4 +
 adapt_strategy.rs                | 572 +++++++++++++++++++++++++++++++
 example.jl                       |  78 +++++
 nutpie.jl                        |  56 +++
 rewrite.jl                       | 134 ++++++++
 src/adaptation/Adaptation.jl     |  15 +-
 src/adaptation/massmatrix.jl     | 133 +++++--
 src/adaptation/nutpie_adaptor.jl | 115 +++++++
 src/sampler.jl                   |  46 ++-
 9 files changed, 1114 insertions(+), 39 deletions(-)
 create mode 100644 adapt_strategy.rs
 create mode 100644 example.jl
 create mode 100644 nutpie.jl
 create mode 100644 rewrite.jl
 create mode 100644 src/adaptation/nutpie_adaptor.jl

diff --git a/Project.toml b/Project.toml
index 35412a9b..1c1723ed 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,11 +5,15 @@ version = "0.4.1"
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/adapt_strategy.rs b/adapt_strategy.rs
new file mode 100644
index 00000000..1d694e5f
--- /dev/null
+++ b/adapt_strategy.rs
@@ -0,0 +1,572 @@
+use std::{fmt::Debug, marker::PhantomData};
+
+use itertools::izip;
+
+use crate::{
+    cpu_potential::{CpuLogpFunc, EuclideanPotential},
+    mass_matrix::{
+        DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance,
+    },
+    nuts::{
+        AdaptStrategy, AsSampleStatVec, Collector, Hamiltonian, NutsOptions, SampleStatItem,
+        SampleStatValue,
+    },
+    stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
+};
+
+const LOWER_LIMIT: f64 = 1e-10f64;
+const UPPER_LIMIT: f64 = 1e10f64;
+
+pub(crate) struct DualAverageStrategy<F, M> {
+    step_size_adapt: DualAverage,
+    options: DualAverageSettings,
+    enabled: bool,
+    finalized: bool,
+    _phantom1: PhantomData<F>,
+    _phantom2: PhantomData<M>,
+}
+
+impl<F, M> DualAverageStrategy<F, M> {
+    fn enable(&mut self) {
+        self.enabled = true;
+    }
+
+    fn finalize(&mut self) {
+        self.finalized = true;
+    }
+}
+
+
+#[derive(Debug, Clone, Copy)]
+pub struct DualAverageStats {
+    pub step_size_bar: f64,
+    pub mean_tree_accept: f64,
+    pub n_steps: u64,
+}
+
+impl AsSampleStatVec for DualAverageStats {
+    fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
+        vec.push(("step_size_bar", SampleStatValue::F64(self.step_size_bar)));
+        vec.push((
+            "mean_tree_accept",
+            SampleStatValue::F64(self.mean_tree_accept),
+        ));
+        vec.push(("n_steps", SampleStatValue::U64(self.n_steps)));
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DualAverageSettings {
+    pub target_accept: f64,
+    pub initial_step: f64,
+    pub params: DualAverageOptions,
+}
+
+impl Default for DualAverageSettings {
+    fn default() -> Self {
+        Self {
+            target_accept: 0.8,
+            initial_step: 0.1,
+            params: DualAverageOptions::default(),
+        }
+    }
+}
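+// Editorial note (not part of the nuts-rs source): by all appearances `DualAverage`
+// is the standard dual-averaging step-size controller of Hoffman & Gelman (2014).
+// Each `advance(mean_accept, target_accept)` call below nudges the log step size so
+// that the running mean acceptance statistic approaches `target_accept`; the
+// smoothed (averaged) value returned by `current_step_size_adapted` is the step
+// size that is kept once warmup ends.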
+impl<F: CpuLogpFunc, M: MassMatrix> AdaptStrategy for DualAverageStrategy<F, M> {
+    type Potential = EuclideanPotential<F, M>;
+    type Collector = AcceptanceRateCollector<<EuclideanPotential<F, M> as Hamiltonian>::State>;
+    type Stats = DualAverageStats;
+    type Options = DualAverageSettings;
+
+    fn new(options: Self::Options, _num_tune: u64, _dim: usize) -> Self {
+        Self {
+            options,
+            enabled: true,
+            step_size_adapt: DualAverage::new(options.params, options.initial_step),
+            finalized: false,
+            _phantom1: PhantomData::default(),
+            _phantom2: PhantomData::default(),
+        }
+    }
+
+    fn init(
+        &mut self,
+        _options: &mut NutsOptions,
+        potential: &mut Self::Potential,
+        _state: &<Self::Potential as Hamiltonian>::State,
+    ) {
+        potential.step_size = self.options.initial_step;
+    }
+
+    fn adapt(
+        &mut self,
+        _options: &mut NutsOptions,
+        potential: &mut Self::Potential,
+        _draw: u64,
+        collector: &Self::Collector,
+    ) {
+        if self.finalized {
+            self.step_size_adapt
+                .advance(collector.mean.current(), self.options.target_accept);
+            potential.step_size = self.step_size_adapt.current_step_size_adapted();
+            return;
+        }
+        if !self.enabled {
+            return;
+        }
+        self.step_size_adapt
+            .advance(collector.mean.current(), self.options.target_accept);
+        potential.step_size = self.step_size_adapt.current_step_size()
+    }
+
+    fn new_collector(&self) -> Self::Collector {
+        AcceptanceRateCollector::new()
+    }
+
+    fn current_stats(
+        &self,
+        _options: &NutsOptions,
+        _potential: &Self::Potential,
+        collector: &Self::Collector,
+    ) -> Self::Stats {
+        DualAverageStats {
+            step_size_bar: self.step_size_adapt.current_step_size_adapted(),
+            mean_tree_accept: collector.mean.current(),
+            n_steps: collector.mean.count(),
+        }
+    }
+}
+
+/// Settings for mass matrix adaptation
+#[derive(Clone, Copy, Debug)]
+pub struct DiagAdaptExpSettings {
+    pub store_mass_matrix: bool,
+}
+
+impl Default for DiagAdaptExpSettings {
+    fn default() -> Self {
+        Self {
+            store_mass_matrix: false,
+        }
+    }
+}
+
+pub(crate) struct ExpWindowDiagAdapt<F> {
+    dim: usize,
+    exp_variance_draw: RunningVariance,
+    exp_variance_grad: RunningVariance,
+    exp_variance_grad_bg: RunningVariance,
+    exp_variance_draw_bg: RunningVariance,
+    settings: DiagAdaptExpSettings,
+    _phantom: PhantomData<F>,
+}
+
+impl<F: CpuLogpFunc> ExpWindowDiagAdapt<F> {
+    fn update_estimators(&mut self, collector: &DrawGradCollector) {
+        if collector.is_good {
+            self.exp_variance_draw
+                .add_sample(collector.draw.iter().copied());
+            self.exp_variance_grad
+                .add_sample(collector.grad.iter().copied());
+            self.exp_variance_draw_bg
+                .add_sample(collector.draw.iter().copied());
+            self.exp_variance_grad_bg
+                .add_sample(collector.grad.iter().copied());
+        }
+    }
+
+    fn switch(&mut self, collector: &DrawGradCollector) {
+        self.exp_variance_draw = std::mem::replace(
+            &mut self.exp_variance_draw_bg,
+            RunningVariance::new(self.dim),
+        );
+        self.exp_variance_grad = std::mem::replace(
+            &mut self.exp_variance_grad_bg,
+            RunningVariance::new(self.dim),
+        );
+
+        self.update_estimators(collector);
+    }
+
+    fn current_count(&self) -> u64 {
+        assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count());
+        self.exp_variance_draw.count()
+    }
+
+    fn background_count(&self) -> u64 {
+        assert!(self.exp_variance_draw_bg.count() == self.exp_variance_grad_bg.count());
+        self.exp_variance_draw_bg.count()
+    }
+
+    fn update_potential(&self, potential: &mut EuclideanPotential<F, DiagMassMatrix>) {
+        if self.current_count() < 3 {
+            return;
+        }
+        assert!(self.current_count() > 2);
+        potential.mass_matrix.update_diag(
+            izip!(
+                self.exp_variance_draw.current(),
+                self.exp_variance_grad.current(),
+            )
+            .map(|(draw, grad)| {
+                let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
+                if !val.is_finite() {
+                    None
+                } else {
+                    Some(val)
+                }
+            }),
+        );
+    }
+}
+
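+// Editorial note: the adapter keeps two running estimates per quantity. Every good
+// draw feeds both the foreground ("current") and background (`_bg`) variances; on
+// `switch`, the background estimate (which has only seen the most recent window)
+// is promoted to foreground and a fresh background is started. The mass-matrix
+// diagonal is then sqrt(var(draw) / var(grad)), clamped elementwise to
+// [LOWER_LIMIT, UPPER_LIMIT], with non-finite entries left unchanged (None).
+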
+#[derive(Clone, Debug)]
+pub struct ExpWindowDiagAdaptStats {
+    pub mass_matrix_inv: Option<Box<[f64]>>,
+}
+
+impl AsSampleStatVec for ExpWindowDiagAdaptStats {
+    fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
+        vec.push((
+            "mass_matrix_inv",
+            SampleStatValue::OptionArray(self.mass_matrix_inv.clone()),
+        ));
+    }
+}
+
+impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
+    type Potential = EuclideanPotential<F, DiagMassMatrix>;
+    type Collector = DrawGradCollector;
+    type Stats = ExpWindowDiagAdaptStats;
+    type Options = DiagAdaptExpSettings;
+
+    fn new(options: Self::Options, _num_tune: u64, dim: usize) -> Self {
+        Self {
+            dim,
+            exp_variance_draw: RunningVariance::new(dim),
+            exp_variance_grad: RunningVariance::new(dim),
+            exp_variance_draw_bg: RunningVariance::new(dim),
+            exp_variance_grad_bg: RunningVariance::new(dim),
+            settings: options,
+            _phantom: PhantomData::default(),
+        }
+    }
+
+    fn init(
+        &mut self,
+        _options: &mut NutsOptions,
+        potential: &mut Self::Potential,
+        state: &<Self::Potential as Hamiltonian>::State,
+    ) {
+        self.exp_variance_draw.add_sample(state.q.iter().copied());
+        self.exp_variance_draw_bg.add_sample(state.q.iter().copied());
+        self.exp_variance_grad.add_sample(state.grad.iter().copied());
+        self.exp_variance_grad_bg.add_sample(state.grad.iter().copied());
+
+        potential.mass_matrix.update_diag(
+            state.grad.iter().map(|&grad| {
+                Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
+            })
+        );
+
+    }
+
+    fn adapt(
+        &mut self,
+        _options: &mut NutsOptions,
+        _potential: &mut Self::Potential,
+        _draw: u64,
+        _collector: &Self::Collector,
+    ) {
+        // Must be controlled from a different meta strategy
+    }
+
+    fn new_collector(&self) -> Self::Collector {
+        DrawGradCollector::new(self.dim)
+    }
+
+    fn current_stats(
+        &self,
+        _options: &NutsOptions,
+        potential: &Self::Potential,
+        _collector: &Self::Collector,
+    ) -> Self::Stats {
+        let diag = if self.settings.store_mass_matrix {
+            Some(potential.mass_matrix.variance.clone())
+        } else {
+            None
+        };
+        ExpWindowDiagAdaptStats {
+            mass_matrix_inv: diag,
+        }
+    }
+}
+
+
+pub(crate) struct GradDiagStrategy<F> {
+    step_size: DualAverageStrategy<F, DiagMassMatrix>,
+    mass_matrix: ExpWindowDiagAdapt<F>,
+    options: GradDiagOptions,
+    num_tune: u64,
+    // The number of draws in the early window
+    early_end: u64,
+
+    // The first draw number for the final step size adaptation window
+    final_step_size_window: u64,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct GradDiagOptions {
+    pub dual_average_options: DualAverageSettings,
+    pub mass_matrix_options: DiagAdaptExpSettings,
+    pub early_window: f64,
+    pub step_size_window: f64,
+    pub mass_matrix_switch_freq: u64,
+    pub early_mass_matrix_switch_freq: u64,
+}
+
+impl Default for GradDiagOptions {
+    fn default() -> Self {
+        Self {
+            dual_average_options: DualAverageSettings::default(),
+            mass_matrix_options: DiagAdaptExpSettings::default(),
+            early_window: 0.3,
+            //step_size_window: 0.08,
+            //step_size_window: 0.15,
+            step_size_window: 0.2,
+            mass_matrix_switch_freq: 60,
+            early_mass_matrix_switch_freq: 10,
+        }
+    }
+}
+
+impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
+    type Potential = EuclideanPotential<F, DiagMassMatrix>;
+    type Collector = CombinedCollector<
+        AcceptanceRateCollector<<EuclideanPotential<F, DiagMassMatrix> as Hamiltonian>::State>,
+        DrawGradCollector
+    >;
+    type Stats = CombinedStats<DualAverageStats, ExpWindowDiagAdaptStats>;
+    type Options = GradDiagOptions;
+
+    fn new(options: Self::Options, num_tune: u64, dim: usize) -> Self {
+        let num_tune_f = num_tune as f64;
+        let step_size_window = (options.step_size_window * num_tune_f) as u64;
+        let early_end = (options.early_window * num_tune_f) as u64;
+        let final_second_step_size = num_tune.saturating_sub(step_size_window);
+
+        assert!(early_end < num_tune);
+
+        Self {
+            step_size: DualAverageStrategy::new(options.dual_average_options, num_tune, dim),
+            mass_matrix: ExpWindowDiagAdapt::new(options.mass_matrix_options, num_tune, dim),
+            options,
+            num_tune,
+            early_end,
+            final_step_size_window: final_second_step_size,
+        }
+    }
+
+    fn init(
+        &mut self,
+        options: &mut NutsOptions,
+        potential: &mut Self::Potential,
+        state: &<Self::Potential as Hamiltonian>::State,
+    ) {
+        self.step_size.init(options, potential, state);
+        self.mass_matrix.init(options, potential, state);
+        self.step_size.enable();
+    }
+
+    fn adapt(
+        &mut self,
+        options: &mut NutsOptions,
+        potential: &mut Self::Potential,
+        draw: u64,
+        collector: &Self::Collector,
+    ) {
+        if draw >= self.num_tune {
+            return;
+        }
+
+        if draw < self.final_step_size_window {
+            let is_early = draw < self.early_end;
+            let switch_freq = if is_early {
+                self.options.early_mass_matrix_switch_freq
+            } else {
+                self.options.mass_matrix_switch_freq
+            };
+
+            if self.mass_matrix.background_count() >= switch_freq {
+                self.mass_matrix.switch(&collector.collector2);
+            } else {
+                self.mass_matrix.update_estimators(&collector.collector2);
+            }
+            self.mass_matrix.update_potential(potential);
+            self.step_size.adapt(options, potential, draw, &collector.collector1);
+            return;
+        }
+
+        if draw == self.num_tune - 1 {
+            self.step_size.finalize();
+        }
+        self.step_size.adapt(options, potential, draw, &collector.collector1);
+    }
+
+    fn new_collector(&self) -> Self::Collector {
+        CombinedCollector {
+            collector1: self.step_size.new_collector(),
+            collector2: self.mass_matrix.new_collector(),
+        }
+    }
+
+    fn current_stats(
+        &self,
+        options: &NutsOptions,
+        potential: &Self::Potential,
+        collector: &Self::Collector,
+    ) -> Self::Stats {
+        CombinedStats {
+            stats1: self
+                .step_size
+                .current_stats(options, potential, &collector.collector1),
+            stats2: self
+                .mass_matrix
+                .current_stats(options, potential, &collector.collector2),
+        }
+    }
+}
+
+
+#[derive(Debug, Clone)]
+pub struct CombinedStats<D1, D2> {
+    pub stats1: D1,
+    pub stats2: D2,
+}
+
+impl<D1: AsSampleStatVec, D2: AsSampleStatVec> AsSampleStatVec for CombinedStats<D1, D2> {
+    fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
+        self.stats1.add_to_vec(vec);
+        self.stats2.add_to_vec(vec);
+    }
+}
+
+pub(crate) struct CombinedCollector<C1, C2> {
+    collector1: C1,
+    collector2: C2,
+}
+
+impl<C1, C2> Collector for CombinedCollector<C1, C2>
+where
+    C1: Collector,
+    C2: Collector<State = C1::State>,
+{
+    type State = C1::State;
+
+    fn register_leapfrog(
+        &mut self,
+        start: &Self::State,
+        end: &Self::State,
+        divergence_info: Option<&dyn crate::nuts::DivergenceInfo>,
+    ) {
+        self.collector1
+            .register_leapfrog(start, end, divergence_info);
+        self.collector2
+            .register_leapfrog(start, end, divergence_info);
+    }
+
+    fn register_draw(&mut self, state: &Self::State, info: &crate::nuts::SampleInfo) {
+        self.collector1.register_draw(state, info);
+        self.collector2.register_draw(state, info);
+    }
+
+    fn register_init(&mut self, state: &Self::State, options: &crate::nuts::NutsOptions) {
+        self.collector1.register_init(state, options);
+        self.collector2.register_init(state, options);
+    }
+}
+
+#[cfg(test)]
+pub mod test_logps {
+    use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError};
+    use thiserror::Error;
+
+    #[derive(Clone)]
+    pub struct NormalLogp {
+        dim: usize,
+        mu: f64,
+    }
+
+    impl NormalLogp {
+        pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp {
+            NormalLogp { dim, mu }
+        }
+    }
+
+    #[derive(Error, Debug)]
+    pub enum NormalLogpError {}
+    impl LogpError for NormalLogpError {
+        fn is_recoverable(&self) -> bool {
+            false
+        }
+    }
+
+    impl CpuLogpFunc for NormalLogp {
+        type Err = NormalLogpError;
+
+        fn dim(&self) -> usize {
+            self.dim
+        }
+        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::Err> {
+            let n = position.len();
+            assert!(gradient.len() == n);
+
+            let mut logp = 0f64;
+            for (p, g) in position.iter().zip(gradient.iter_mut()) {
+                let val = *p - self.mu;
+                logp -= val * val / 2.;
+                *g = -val;
+            }
+            Ok(logp)
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::test_logps::NormalLogp;
+    use super::*;
+    use crate::nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions};
+
+    #[test]
+    fn instanciate_adaptive_sampler() {
+        let ndim = 10;
+        let func = NormalLogp::new(ndim, 3.);
+        let num_tune = 100;
+        let options = GradDiagOptions::default();
+        let strategy = GradDiagStrategy::new(options, num_tune, ndim);
+
+        let mass_matrix = DiagMassMatrix::new(ndim);
+        let max_energy_error = 1000f64;
+        let step_size = 0.1f64;
+
+        let potential = EuclideanPotential::new(func, mass_matrix, max_energy_error, step_size);
+        let options = NutsOptions {
+            maxdepth: 10u64,
+            store_gradient: true,
+        };
+
+        let rng = {
+            use rand::SeedableRng;
+            rand::rngs::StdRng::seed_from_u64(42)
+        };
+        let chain = 0u64;
+
+        let mut sampler = NutsChain::new(potential, strategy, options, rng, chain);
+        sampler.set_position(&vec![1.5f64; ndim]).unwrap();
+        for _ in 0..200 {
+            sampler.draw().unwrap();
+        }
+    }
+}
\ No newline at end of file
diff --git a/example.jl b/example.jl
new file mode 100644
index 00000000..f0679a2a
--- /dev/null
+++ b/example.jl
@@ -0,0 +1,78 @@
+using AdvancedHMC, ForwardDiff
+using LogDensityProblems
+using LinearAlgebra
+using Distributions
+
+# Define the target distribution using the `LogDensityProblem` interface
+struct LogTargetDensity
+    dim::Int
+end
+LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
+LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
+LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()
+
+# Choose parameter dimensionality and initial parameter value
+D = 10;
+initial_θ = rand(D);
+ℓπ = LogTargetDensity(D)
+
+# Set the number of samples to draw and warmup iterations
+n_samples, n_adapts = 2_000, 1_000
+
+# Define a Hamiltonian system
+metric = DiagEuclideanMetric(D)
+hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
+
+# Define a leapfrog solver, with initial step size chosen heuristically
+initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
+integrator = Leapfrog(initial_ϵ)
+
+# Define an HMC sampler, with the following components
+# - multinomial sampling scheme,
+# - generalised No-U-Turn criteria, and
+# - windowed adaption for step-size and diagonal mass matrix
+proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
+adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
+
+# Run the sampler to draw samples from the specified Gaussian, where
+# - `samples` will store the samples
+# - `stats` will store diagnostic statistics for each sample
+samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true)
+
+
+# Other Example
+# Setup
+using Distributions: Distributions
+using Bijectors: Bijectors
+using Random
+struct LogDensityDistribution{D<:Distributions.Distribution}
+    dist::D
+end
+
+LogDensityProblems.dimension(d::LogDensityDistribution) = length(d.dist)
+function LogDensityProblems.logdensity(ld::LogDensityDistribution, y)
+    d = ld.dist
+    b = Bijectors.inverse(Bijectors.bijector(d))
+    x, 
logjac = Bijectors.with_logabsdet_jacobian(b, y)
+    return logpdf(d, x) + logjac
+end
+LogDensityProblems.capabilities(::Type{<:LogDensityDistribution}) = LogDensityProblems.LogDensityOrder{0}()
+
+# Random variance
+n_samples, n_adapts = 2_000, 1_000
+Random.seed!(1)
+D = 10
+σ² = 1 .+ abs.(randn(D))
+
+# Diagonal Gaussian
+ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²)))
+metric = DiagEuclideanMetric(D)
+θ_init = rand(D)
+h = Hamiltonian(metric, ℓπ, ForwardDiff)
+κ = NUTS(Leapfrog(find_good_stepsize(h, θ_init)))
+adaptor = StanHMCAdaptor(
+    MassMatrixAdaptor(metric),
+    StepSizeAdaptor(0.8, κ.τ.integrator)
+)
+samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose=false)
+# adaptor.pc.var ≈ σ²
diff --git a/nutpie.jl b/nutpie.jl
new file mode 100644
index 00000000..de5afa90
--- /dev/null
+++ b/nutpie.jl
@@ -0,0 +1,56 @@
+# Example for Nuts-rs / Nutpie Adaptor
+using AdvancedHMC, ForwardDiff
+using LogDensityProblems
+using LinearAlgebra
+using Distributions
+using Plots
+const A = AdvancedHMC
+
+# Define the target distribution using the `LogDensityProblem` interface
+struct LogTargetDensity
+    dim::Int
+end
+LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
+LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
+LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()
+
+# Choose parameter dimensionality and initial parameter value
+D = 20;
+initial_θ = rand(D);
+ℓπ = LogTargetDensity(D)
+n_samples, n_adapts = 2_000, 200
+
+# DEFAULT
+metric = DiagEuclideanMetric(D)
+hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
+initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
+integrator = Leapfrog(initial_ϵ)
+proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
+adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
+
+@time samples1, stats1 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true);
+
+# NUTPIE
+# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs
+metric = DiagEuclideanMetric(D)
+hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
+initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
+integrator = Leapfrog(initial_ϵ)
+proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
+pc = A.ExpWeightedWelfordVar(size(metric))
+adaptor = A.NutpieHMCAdaptor(pc, StepSizeAdaptor(0.8, integrator))
+@time samples2, stats2 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true);
+
+# # Plots
+# Plot of variance
+get_component(samples, idx) = [samples[i][idx] for i in 1:length(samples)]
+# get_component(samples, 3)
+
+# Comparison of trace plots
+idx = 10
+plot(plot(get_component(samples1, idx), label="Default", color=palette(:default)[1]),
+    plot(get_component(samples2, idx), label="Nutpie/Nuts-rs", color=palette(:default)[2]),
+    plot_title="Comparison of component $idx", layout=(2, 1))
+
+# Histogram
+pl = histogram(get_component(samples1, idx), label="Default", fillstyle=:\, color=palette(:default)[1], alpha=0.5)
+histogram!(pl, get_component(samples2, idx), label="Nutpie/Nuts-rs", alpha=0.5, title="Comparison of component $idx")
\ No newline at end of file
diff --git a/rewrite.jl b/rewrite.jl
new file mode 100644
index 00000000..c330449a
--- /dev/null
+++ b/rewrite.jl
@@ -0,0 +1,134 @@
+using StanHMCAdaptor
+
+mutable struct DiagAdaptExpSettings <: StanHMCAdaptorSettings
+    store_mass_matrix::Bool
+end
+
+DiagAdaptExpSettings() = 
DiagAdaptExpSettings(false) + +mutable struct ExpWindowDiagAdapt{F} <: StanHMCAdaptor{F} + dim::Int + exp_variance_draw::RunningVariance + exp_variance_grad::RunningVariance + exp_variance_grad_bg::RunningVariance + exp_variance_draw_bg::RunningVariance + settings::DiagAdaptExpSettings + _phantom::Phantom{F} +end + +function ExpWindowDiagAdapt(dim::Int, settings::DiagAdaptExpSettings) + ExpWindowDiagAdapt(dim, RunningVariance(dim), RunningVariance(dim), RunningVariance(dim), RunningVariance(dim), settings, Phantom{F}()) +end + +function update!(adaptor::ExpWindowDiagAdapt, state::StanHMCAdaptorState, collector::DrawGradCollector) + if collector.is_good + for i in 1:adaptor.dim + adaptor.exp_variance_draw.add_sample(collector.draw[i]) + adaptor.exp_variance_grad.add_sample(collector.grad[i]) + adaptor.exp_variance_draw_bg.add_sample(collector.draw[i]) + adaptor.exp_variance_grad_bg.add_sample(collector.grad[i]) + end + end + if adaptor.exp_variance_draw.count() >= 3 + for i in 1:adaptor.dim + diag = (adaptor.exp_variance_draw.current()[i] / adaptor.exp_variance_grad.current()[i])^0.5 + diag = max(LOWER_LIMIT, min(UPPER_LIMIT, diag)) + if isfinite(diag) + state.mass_matrix.update_diag[i] = diag + end + end + end +end + +function initialize_state(adaptor::ExpWindowDiagAdapt) + return StanHMCAdaptorState() +end + +mutable struct ExpWindowDiagAdaptState <: StanHMCAdaptorState + mass_matrix_inv::Union{Nothing,Vector{Float64}} +end + +function create_adaptor_state(adaptor::ExpWindowDiagAdapt) + return ExpWindowDiagAdaptState(nothing) +end + +function sample_stats(adaptor::ExpWindowDiagAdapt, state::ExpWindowDiagAdaptState) + ExpWindowDiagAdaptState(state.mass_matrix_inv) +end + +# Grad + +mutable struct GradDiagStrategy{F} <: StanHMCAdaptor{F} + step_size::DualAverageStrategy{F,DiagMassMatrix} + mass_matrix::ExpWindowDiagAdapt{F} + options::GradDiagOptions + num_tune::UInt64 + early_end::UInt64 + final_step_size_window::UInt64 +end + +mutable struct GradDiagOptions + dual_average_options::DualAverageSettings + mass_matrix_options::DiagAdaptExpSettings + early_window::Float64 + step_size_window::Float64 + mass_matrix_switch_freq::UInt64 + early_mass_matrix_switch_freq::UInt64 +end + +GradDiagOptions() = GradDiagOptions(DualAverageSettings(), DiagAdaptExpSettings(), 0.3, 0.2, 60, 10) + +mutable struct GradDiagStats <: StanHMCAdaptorStats + step_size_stats::DualAverageStats + mass_matrix_stats::ExpWindowDiagAdaptStats +end + +function GradDiagStrategy(options::GradDiagOptions, num_tune::UInt64, dim::Int) + num_tune_f = convert(Float64, num_tune) + step_size_window = convert(UInt64, options.step_size_window + * + num_tune_f) + early_end = convert(UInt64, options.early_window * num_tune_f) + final_second_step_size = max(num_tune - convert(UInt64, step_size_window), 0) + + GradDiagStrategy(DualAverageStrategy(options.dual_average_options, num_tune, dim), + ExpWindowDiagAdapt(dim, options.mass_matrix_options), + options, + num_tune, + early_end, + final_second_step_size) +end + +function update!(adaptor::GradDiagStrategy, state::StanHMCAdaptorState, collector::DrawGradCollector) + if collector.is_good + step_size_stats = update!(adaptor.step_size, state, collector) + mass_matrix_stats = update!(adaptor.mass_matrix, state, collector) + end + if adaptor.draw >= adaptor.num_tune + return + end + if adaptor.draw < adaptor.final_step_size_window + is_early = adaptor.draw < adaptor.early_end + switch_freq = is_early ? 
adaptor.options.early_mass_matrix_switch_freq : adaptor.options.mass_matrix_switch_freq + if adaptor.mass_matrix.background_count() >= switch_freq + adaptor.mass_matrix.switch(collector) + end + end + if adaptor.draw >= adaptor.final_step_size_window + adaptor.mass_matrix.update_potential(potential) + end +end + +function initialize_state(adaptor::GradDiagStrategy) + return initialize_state(adaptor.mass_matrix) +end + +function create_adaptor_state(adaptor::GradDiagStrategy) + return create_adaptor_state(adaptor.mass_matrix) +end + +function sample_stats(adaptor::GradDiagStrategy, state::StanHMCAdaptorState) + step_size_stats = sample_stats(adaptor.step_size, state) + mass_matrix_stats = sample_stats(adaptor.mass_matrix, state) + GradDiagStats(step_size_stats, mass_matrix_stats) +end \ No newline at end of file diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl index 8aad626a..52417753 100644 --- a/src/adaptation/Adaptation.jl +++ b/src/adaptation/Adaptation.jl @@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra using Statistics: Statistics using UnPack: @unpack, @pack! -using ..AdvancedHMC: DEBUG, AbstractScalarOrVec +using ..AdvancedHMC: DEBUG, AbstractScalarOrVec, PhasePoint abstract type AbstractAdaptor end function getM⁻¹ end @@ -28,9 +28,9 @@ export MassMatrixAdaptor, UnitMassMatrix, WelfordVar, WelfordCov ## TODO: generalise this to a list of adaptors ## -struct NaiveHMCAdaptor{M<:MassMatrixAdaptor, Tssa<:StepSizeAdaptor} <: AbstractAdaptor - pc :: M - ssa :: Tssa +struct NaiveHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAdaptor + pc::M + ssa::Tssa end Base.show(io::IO, a::NaiveHMCAdaptor) = print(io, "NaiveHMCAdaptor(pc=$(a.pc), ssa=$(a.ssa))") @@ -56,4 +56,11 @@ finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa) include("stan_adaptor.jl") export NaiveHMCAdaptor, StanHMCAdaptor +const LOWER_LIMIT::Float64 = 1e-10 +const UPPER_LIMIT::Float64 = 1e10 + +include("nutpie_adaptor.jl") +export NutpieHMCAdaptor, ExpWeightedWelfordVar + + end # module diff --git a/src/adaptation/massmatrix.jl b/src/adaptation/massmatrix.jl index 4aeb1a85..f9db55f0 100644 --- a/src/adaptation/massmatrix.jl +++ b/src/adaptation/massmatrix.jl @@ -54,15 +54,15 @@ function update!(ve::DiagMatrixEstimator) end # NOTE: this naive variance estimator is used only in testing -struct NaiveVar{T<:AbstractFloat, E<:AbstractVector{<:AbstractVecOrMat{T}}} <: DiagMatrixEstimator{T} - S :: E - NaiveVar(S::E) where {E} = new{eltype(eltype(E)), E}(S) +struct NaiveVar{T<:AbstractFloat,E<:AbstractVector{<:AbstractVecOrMat{T}}} <: DiagMatrixEstimator{T} + S::E + NaiveVar(S::E) where {E} = new{eltype(eltype(E)),E}(S) end NaiveVar{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Vector{T}}()) NaiveVar{T}(sz::Tuple{Int,Int}) where {T<:AbstractFloat} = NaiveVar(Vector{Matrix{T}}()) -NaiveVar(sz::Union{Tuple{Int}, Tuple{Int,Int}}) = NaiveVar{Float64}(sz) +NaiveVar(sz::Union{Tuple{Int},Tuple{Int,Int}}) = NaiveVar{Float64}(sz) Base.push!(nv::NaiveVar, s::AbstractVecOrMat) = push!(nv.S, s) @@ -74,28 +74,28 @@ function get_estimation(nv::NaiveVar) end # Ref: https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/welford_var_estimator.hpp -mutable struct WelfordVar{T<:AbstractFloat, E<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T} - n :: Int - n_min :: Int - μ :: E - M :: E - δ :: E # cache for diff - var :: E # cache for variance +mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T} + n::Int + n_min::Int + μ::E + M::E + δ::E # 
cache for diff + var::E # cache for variance function WelfordVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::E) where {E} - return new{eltype(E), E}(n, n_min, μ, M, δ, var) + return new{eltype(E),E}(n, n_min, μ, M, δ, var) end end Base.show(io::IO, ::WelfordVar) = print(io, "WelfordVar") function WelfordVar{T}( - sz::Union{Tuple{Int}, Tuple{Int,Int}}; + sz::Union{Tuple{Int},Tuple{Int,Int}}; n_min::Int=10, var=ones(T, sz) ) where {T<:AbstractFloat} return WelfordVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) end -WelfordVar(sz::Union{Tuple{Int}, Tuple{Int,Int}}; kwargs...) = WelfordVar{Float64}(sz; kwargs...) +WelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = WelfordVar{Float64}(sz; kwargs...) function Base.resize!(wv::WelfordVar, θ::AbstractVecOrMat{T}) where {T<:AbstractFloat} if size(θ) != size(wv.var) @@ -130,6 +130,91 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat} return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5)) end +# Rust implementation of NUTS used in Nutpie (comes from nuts-rs crate) +# Source: https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs + +mutable struct ExpWeightedWelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T} + exp_variance_draw::WelfordVar{T,E} + exp_variance_grad::WelfordVar{T,E} + exp_variance_draw_bg::WelfordVar{T,E} + exp_variance_grad_bg::WelfordVar{T,E} + function ExpWeightedWelfordVar(exp_variance_draw::WelfordVar{T,E}, exp_variance_grad::WelfordVar{T,E}, exp_variance_draw_bg::WelfordVar{T,E}, exp_variance_grad_bg::WelfordVar{T,E}) where {T,E} + return new{eltype(E),E}(exp_variance_draw, exp_variance_grad, exp_variance_draw_bg, exp_variance_grad_bg) + end +end + +Base.show(io::IO, ::ExpWeightedWelfordVar) = print(io, "ExpWeightedWelfordVar") + +function ExpWeightedWelfordVar{T}( + sz::Union{Tuple{Int},Tuple{Int,Int}}; + n_min::Int=4, var=ones(T, sz) +) where {T<:AbstractFloat} + # return ExpWeightedWelfordVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) + return ExpWeightedWelfordVar(WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var)) +end + +ExpWeightedWelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = ExpWeightedWelfordVar{Float64}(sz; kwargs...) + +function Base.resize!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, g::AbstractVecOrMat{T}) where {T<:AbstractFloat} + @assert size(θ) == size(g) "Size of draw and grad must be the same." + resize!(wv.exp_variance_draw, θ) + resize!(wv.exp_variance_grad, g) + resize!(wv.exp_variance_draw_bg, θ) + resize!(wv.exp_variance_grad_bg, g) +end + +function reset!(wv::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} + reset!(wv.exp_variance_draw) + reset!(wv.exp_variance_grad) + reset!(wv.exp_variance_draw_bg) + reset!(wv.exp_variance_grad_bg) +end + +function Base.push!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, g::AbstractVecOrMat{T}) where {T} + @assert size(θ) == size(g) "Size of draw and grad must be the same." 
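+    # Each (draw, grad) pair feeds both the foreground ("current") estimators and
+    # the background ones; at the next `switch!` the background pair is promoted.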
+ push!(wv.exp_variance_draw, θ) + push!(wv.exp_variance_grad, g) + push!(wv.exp_variance_draw_bg, θ) + push!(wv.exp_variance_grad_bg, g) +end + +# swap the background and foreground estimators for both _draw and _grad variance +# unlike the Rust implementation, we don't update the estimators inside of the switch as well (called separately) +function switch!(wv::ExpWeightedWelfordVar) + wv.exp_variance_draw = wv.exp_variance_draw_bg + reset!(wv.exp_variance_draw_bg) + wv.exp_variance_grad = wv.exp_variance_grad_bg + reset!(wv.exp_variance_grad_bg) +end +current_count(wv) = wv.exp_variance_draw.n +background_count(wv) = wv.exp_variance_draw_bg.n + +function adapt!( + adaptor::ExpWeightedWelfordVar, + θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, + g::AbstractVecOrMat{<:AbstractFloat}, + is_update::Bool=true +) + resize!(adaptor, θ, g) + push!(adaptor, θ, g) + is_update && update!(adaptor) +end + +# mimics: let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT); +# TODO: handle NaN +function get_estimation(ad::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} + var_draw = get_estimation(ad.exp_variance_draw) + var_grad = get_estimation(ad.exp_variance_grad) + var = (var_draw ./ var_grad) .|> sqrt .|> x -> clamp(x, LOWER_LIMIT, UPPER_LIMIT) + # re-use the last estimate `var` if the current estimate is not valid + return all(isfinite.(var)) ? var : ad.exp_variance_draw.var +end +# reuse the `var` slot in `WelfordVar` to store the estimated variance of the draw (the "current" one) +function update!(ad::ExpWeightedWelfordVar) + current_count(ad) >= ad.exp_variance_draw.n_min && (ad.exp_variance_draw.var .= get_estimation(ad)) +end + ## Dense mass matrix adaptor abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end @@ -143,14 +228,14 @@ function update!(ce::DenseMatrixEstimator) end # NOTE: This naive covariance estimator is used only in testing. -struct NaiveCov{F<:AbstractFloat, T<:AbstractVector{<:AbstractVector{F}}} <: DenseMatrixEstimator{T} - S :: T - NaiveCov(S::E) where {E} = new{eltype(eltype(E)), E}(S) +struct NaiveCov{F<:AbstractFloat,T<:AbstractVector{<:AbstractVector{F}}} <: DenseMatrixEstimator{T} + S::T + NaiveCov(S::E) where {E} = new{eltype(eltype(E)),E}(S) end NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}()) -NaiveCov(sz::Union{Tuple{Int}, Tuple{Int,Int}}; kwargs...) = NaiveCov{Float64}(sz; kwargs...) +NaiveCov(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = NaiveCov{Float64}(sz; kwargs...) 
Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s) @@ -163,12 +248,12 @@ end # Ref: https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/welford_covar_estimator.hpp mutable struct WelfordCov{F<:AbstractFloat} <: DenseMatrixEstimator{F} - n :: Int - n_min :: Int - μ :: Vector{F} - M :: Matrix{F} - δ :: Vector{F} # cache for diff - cov :: Matrix{F} + n::Int + n_min::Int + μ::Vector{F} + M::Matrix{F} + δ::Vector{F} # cache for diff + cov::Matrix{F} end Base.show(io::IO, ::WelfordCov) = print(io, "WelfordCov") diff --git a/src/adaptation/nutpie_adaptor.jl b/src/adaptation/nutpie_adaptor.jl new file mode 100644 index 00000000..1b33d3c6 --- /dev/null +++ b/src/adaptation/nutpie_adaptor.jl @@ -0,0 +1,115 @@ +### Mutable states + +mutable struct NutpieHMCAdaptorState + i::Int + n_adapts::Int + # The number of draws in the the early window + early_end::Int + # The first draw number for the final step size adaptation window + final_step_size_window::Int + + function NutpieHMCAdaptorState(i, n_adapts, early_end, final_step_size_window) + @assert (early_end < n_adapts) "Early_end must be less than num_tune (provided $early_end and $n_adapts)" + return new(i, n_adapts, early_end, final_step_size_window) + end +end +function NutpieHMCAdaptorState() + return NutpieHMCAdaptorState(0, 1000, 300, 800) +end + +function initialize!(state::NutpieHMCAdaptorState, early_window_share::Float64, + final_step_size_window_share::Float64, + n_adapts::Int) + + early_end = ceil(UInt64, early_window_share * n_adapts) + step_size_window = ceil(UInt64, final_step_size_window_share * n_adapts) + final_step_size_window = max(n_adapts - step_size_window, 0) + 1 + + state.early_end = early_end + state.n_adapts = n_adapts + state.final_step_size_window = final_step_size_window +end + +# function Base.show(io::IO, state::NutpieHMCAdaptorState) +# print(io, "window($(state.window_start), $(state.window_end)), window_splits(" * string(join(state.window_splits, ", ")) * ")") +# end + +### Nutpie's adaptation +# Acknowledgement: ... +struct NutpieHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAdaptor + pc::M + ssa::Tssa + early_window_share::Float64 + final_step_size_window_share::Float64 + mass_matrix_switch_freq::Int + early_mass_matrix_switch_freq::Int + state::NutpieHMCAdaptorState +end +# Base.show(io::IO, a::NutpieHMCAdaptor) = +# print(io, "NutpieHMCAdaptor(\n pc=$(a.pc),\n ssa=$(a.ssa),\n init_buffer=$(a.init_buffer), term_buffer=$(a.term_buffer), window_size=$(a.window_size),\n state=$(a.state)\n)") + +function NutpieHMCAdaptor( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + early_window_share::Float64=0.3, + final_step_size_window_share::Float64=0.2, + mass_matrix_switch_freq::Int=60, + early_mass_matrix_switch_freq::Int=10 +) + return NutpieHMCAdaptor(pc, ssa, early_window_share, final_step_size_window_share, mass_matrix_switch_freq, early_mass_matrix_switch_freq, NutpieHMCAdaptorState()) +end + +# !Q: Is mass_matrix a variance or an inverse of it? It should be inverse, but accumulators are directly variance? +# forward the method to the current draw +# it will then be forwarded to `var` property +getM⁻¹(ca::NutpieHMCAdaptor) = getM⁻¹(ca.pc.exp_variance_draw) +getϵ(ca::NutpieHMCAdaptor) = getϵ(ca.ssa) + +function initialize!(adaptor::NutpieHMCAdaptor, n_adapts::Int, z::PhasePoint) + initialize!(adaptor.state, adaptor.early_window_share, adaptor.final_step_size_window_share, n_adapts) + # !Q: Shall we initialize from the gradient? 
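+    # (Editorial note: for a Gaussian target, ∇logπ(θ) = -θ/σ² componentwise, so
+    # 1/|∇logπ| gives a cheap order-of-magnitude guess of the marginal variances;
+    # the clamp guards against zero or non-finite gradient components.)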
+    # Nutpie initializes the variance estimate with reciprocal of the gradient
+    # Like: Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
+    adaptor.pc.exp_variance_draw.var = (1 ./ abs.(z.ℓπ.gradient)) |> x -> clamp.(x, LOWER_LIMIT, UPPER_LIMIT)
+    return adaptor
+end
+finalize!(adaptor::NutpieHMCAdaptor) = finalize!(adaptor.ssa)
+
+#
+is_in_first_step_size_window(ad::NutpieHMCAdaptor) = ad.state.i <= ad.state.final_step_size_window
+is_in_early_window(ad) = ad.state.i <= ad.state.early_end
+switch_freq(ad::NutpieHMCAdaptor) = is_in_early_window(ad) ? ad.early_mass_matrix_switch_freq : ad.mass_matrix_switch_freq
+#
+# Changes vs Rust implementation
+# - step_size_adapt is at the top
+# - several checks are handled in sampler (finalize adaptation, do not adapt during normal sampling)
+function adapt!(
+    ad::NutpieHMCAdaptor,
+    θ::AbstractVecOrMat{<:AbstractFloat},
+    α::AbstractScalarOrVec{<:AbstractFloat},
+    z::PhasePoint
+)
+    ad.state.i += 1
+
+    adapt!(ad.ssa, θ, α)
+
+    # TODO: do we resize twice? also in update?
+    # !Q: Why do we check resizing several times during iteration? (also in adapt!)
+    resize!(ad.pc, θ, z.ℓπ.gradient) # Resize pre-conditioner if necessary.
+
+    # determine whether to update mass matrix
+    if is_in_first_step_size_window(ad)
+
+        # Switch swaps the background (_bg) values for current, and resets the background values
+        # Frequency of the switch depends on the phase
+        background_count(ad.pc) >= switch_freq(ad) && switch!(ad.pc)
+
+        # TODO: implement a skipper for bad draws
+        # !Q: Why does it always update? (as per Nuts-rs/Nutpie)
+        adapt!(ad.pc, θ, α, z.ℓπ.gradient, true)
+    end
+end
+# Missing: collector checks on divergences or terminating at idx=0 // finite and None estimator
+# adapt! estimator only if collected stuff is good (divergences)
+# init for mass_matrix is grad.abs().recip.clamp(LOWER_LIMIT, UPPER_LIMIT) // init of ExpWindowDiagAdapt
+# add checks that exp_variances are the same length?
diff --git a/src/sampler.jl b/src/sampler.jl
index 39e8e8eb..94fc9d29 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -2,7 +2,7 @@
 update(h::Hamiltonian, ::AbstractAdaptor) = h
 
 function update(
-    h::Hamiltonian, adaptor::Union{MassMatrixAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
+    h::Hamiltonian, adaptor::Union{MassMatrixAdaptor,NaiveHMCAdaptor,StanHMCAdaptor,NutpieHMCAdaptor}
 )
     metric = renew(h.metric, getM⁻¹(adaptor))
     return @set h.metric = metric
@@ -10,7 +10,7 @@ end
 update(τ::Trajectory, ::AbstractAdaptor) = τ
 
 function update(
-    τ::Trajectory, adaptor::Union{StepSizeAdaptor, NaiveHMCAdaptor, StanHMCAdaptor}
+    τ::Trajectory, adaptor::Union{StepSizeAdaptor,NaiveHMCAdaptor,StanHMCAdaptor}
 )
     # FIXME: this does not support change type of `ϵ` (e.g. Float to Vector)
     integrator = update_nom_step_size(τ.integrator, getϵ(adaptor))
@@ -34,8 +34,8 @@ end
 ## Interface functions
 ##
 function sample_init(
-    rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
-    h::Hamiltonian,
+    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
+    h::Hamiltonian,
     θ::AbstractVecOrMat{<:AbstractFloat}
 )
     # Ensure h.metric has the same dim as θ.
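# --- Editorial aside (illustration only, not part of the patch) ---
# A worked example of the Nutpie-style warmup schedule implemented above, using the
# default shares from `NutpieHMCAdaptor` (early_window_share = 0.3,
# final_step_size_window_share = 0.2, switch frequencies 10 and 60) and the
# window arithmetic from `initialize!`:
n_adapts = 1_000
early_end = ceil(Int, 0.3 * n_adapts)                                # 300
final_step_size_window = max(n_adapts - ceil(Int, 0.2 * n_adapts), 0) + 1   # 801
# Draws 1–300   : early phase; background/foreground variances swap every 10 draws.
# Draws 301–801 : main phase; swaps happen every 60 draws instead.
# Draws 802–1000: step-size-only window; the mass matrix is frozen and only the
#                 dual-averaging step-size adaptation keeps running.
# -------------------------------------------------------------------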
@@ -46,8 +46,8 @@ function sample_init(
 end
 
 function transition(
-    rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
-    h::Hamiltonian,
+    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
+    h::Hamiltonian,
     κ::HMCKernel,
     z::PhasePoint,
 )
@@ -74,7 +74,8 @@ function Adaptation.adapt!(
     i::Int,
     n_adapts::Int,
     θ::AbstractVecOrMat{<:AbstractFloat},
-    α::AbstractScalarOrVec{<:AbstractFloat}
+    α::AbstractScalarOrVec{<:AbstractFloat},
+    z::PhasePoint # needs to be added for compat
 )
     isadapted = false
     if i <= n_adapts
@@ -88,6 +89,29 @@ end
     return h, κ, isadapted
 end
 
+# Nutpie adaptor requires access to gradients in the Hamiltonian
+function Adaptation.adapt!(
+    h::Hamiltonian,
+    κ::AbstractMCMCKernel,
+    adaptor::NutpieHMCAdaptor,
+    i::Int,
+    n_adapts::Int,
+    θ::AbstractVecOrMat{<:AbstractFloat},
+    α::AbstractScalarOrVec{<:AbstractFloat},
+    z::PhasePoint
+)
+    isadapted = false
+    if i <= n_adapts
+        i == 1 && Adaptation.initialize!(adaptor, n_adapts, z)
+        adapt!(adaptor, θ, α, z)
+        i == n_adapts && finalize!(adaptor)
+        h = update(h, adaptor)
+        κ = update(κ, adaptor)
+        isadapted = true
+    end
+    return h, κ, isadapted
+end
+
 """
 Progress meter update with all trajectory stats, iteration number and metric shown.
 """
@@ -126,7 +150,7 @@ sample(
     drop_warmup=drop_warmup,
     verbose=verbose,
     progress=progress,
-    (pm_next!)=pm_next!,
+    (pm_next!)=pm_next!
 )
 
 """
@@ -153,7 +177,7 @@ Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`.
 - `progress` controls whether to show the progress meter or not
 """
 function sample(
-    rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}},
+    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
     h::Hamiltonian,
     κ::HMCKernel,
     θ::T,
@@ -178,13 +202,13 @@ function sample(
         t = transition(rng, h, κ, t.z)
         # Adapt h and κ; what mutable is the adaptor
         tstat = stat(t)
-        h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
+        h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate, t.z)
         tstat = merge(tstat, (is_adapt=isadapted,))
         # Update progress meter
         if progress
             # Do include current iteration and mass matrix
             pm_next!(pm, (iterations=i, tstat..., mass_matrix=h.metric))
-        # Report finish of adapation
+            # Report finish of adaptation
         elseif verbose && isadapted && i == n_adapts
             @info "Finished $n_adapts adapation steps" adaptor κ.τ.integrator h.metric
         end

From b778cf1cede0f73d1733dd7118edd851f16c1754 Mon Sep 17 00:00:00 2001
From: J S <49557684+svilupp@users.noreply.github.com>
Date: Sun, 12 Feb 2023 21:01:54 +0000
Subject: [PATCH 2/3] separated Nutpie into separate steps to benchmark
 individually

---
 src/adaptation/Adaptation.jl     |   5 +-
 src/adaptation/massmatrix.jl     |  31 +++--
 src/adaptation/nutpie_adaptor.jl | 207 ++++++++++++++++++++++++++-----
 src/sampler.jl                   |  18 +--
 4 files changed, 210 insertions(+), 51 deletions(-)

diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl
index 9ca5a902..873988eb 100644
--- a/src/adaptation/Adaptation.jl
+++ b/src/adaptation/Adaptation.jl
@@ -8,13 +8,14 @@ using UnPack: @unpack, @pack!
 using ..AdvancedHMC: DEBUG, AbstractScalarOrVec, PhasePoint
 
 abstract type AbstractAdaptor end
+abstract type AbstractHMCAdaptorWithGradients <: AbstractAdaptor end
 function getM⁻¹ end
 function getϵ end
 function adapt! end
 function reset! end
 function initialize! end
 function finalize! 
end -export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹ +export AbstractAdaptor, AbstractHMCAdaptorWithGradients, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹ struct NoAdaptation <: AbstractAdaptor end export NoAdaptation @@ -61,7 +62,7 @@ const LOWER_LIMIT::Float64 = 1e-10 const UPPER_LIMIT::Float64 = 1e10 include("nutpie_adaptor.jl") -export NutpieHMCAdaptor, ExpWeightedWelfordVar +export NutpieHMCAdaptor, ExpWeightedWelfordVar, NutpieHMCAdaptorNoSwitch, NutpieHMCAdaptorNoGradInit,NutpieHMCAdaptorNoSwitchNoGradInit end # module diff --git a/src/adaptation/massmatrix.jl b/src/adaptation/massmatrix.jl index 20e7bf1f..f214eea7 100644 --- a/src/adaptation/massmatrix.jl +++ b/src/adaptation/massmatrix.jl @@ -146,6 +146,9 @@ mutable struct ExpWeightedWelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: end end +# save the best estimate of the variance in the "current" WelfordVar +getM⁻¹(ve::ExpWeightedWelfordVar) = ve.exp_variance_draw.var + Base.show(io::IO, ::ExpWeightedWelfordVar) = print(io, "ExpWeightedWelfordVar") function ExpWeightedWelfordVar{T}( @@ -158,12 +161,12 @@ end ExpWeightedWelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = ExpWeightedWelfordVar{Float64}(sz; kwargs...) -function Base.resize!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, g::AbstractVecOrMat{T}) where {T<:AbstractFloat} - @assert size(θ) == size(g) "Size of draw and grad must be the same." +function Base.resize!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T<:AbstractFloat} + @assert size(θ) == size(∇logπ) "Size of draw and grad must be the same." resize!(wv.exp_variance_draw, θ) - resize!(wv.exp_variance_grad, g) + resize!(wv.exp_variance_grad, ∇logπ) resize!(wv.exp_variance_draw_bg, θ) - resize!(wv.exp_variance_grad_bg, g) + resize!(wv.exp_variance_grad_bg, ∇logπ) end function reset!(wv::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} @@ -173,12 +176,12 @@ function reset!(wv::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} reset!(wv.exp_variance_grad_bg) end -function Base.push!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, g::AbstractVecOrMat{T}) where {T} - @assert size(θ) == size(g) "Size of draw and grad must be the same." +function Base.push!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T} + @assert size(θ) == size(∇logπ) "Size of draw and grad must be the same." 
push!(wv.exp_variance_draw, θ) - push!(wv.exp_variance_grad, g) + push!(wv.exp_variance_grad, ∇logπ) push!(wv.exp_variance_draw_bg, θ) - push!(wv.exp_variance_grad_bg, g) + push!(wv.exp_variance_grad_bg, ∇logπ) end # swap the background and foreground estimators for both _draw and _grad variance @@ -196,28 +199,30 @@ function adapt!( adaptor::ExpWeightedWelfordVar, θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, - g::AbstractVecOrMat{<:AbstractFloat}, + ∇logπ::AbstractVecOrMat{<:AbstractFloat}, is_update::Bool=true ) - resize!(adaptor, θ, g) - push!(adaptor, θ, g) + resize!(adaptor, θ, ∇logπ) + push!(adaptor, θ, ∇logπ) is_update && update!(adaptor) end -# mimics: let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT); # TODO: handle NaN function get_estimation(ad::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} var_draw = get_estimation(ad.exp_variance_draw) var_grad = get_estimation(ad.exp_variance_grad) + # mimics: let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT); var = (var_draw ./ var_grad) .|> sqrt .|> x -> clamp(x, LOWER_LIMIT, UPPER_LIMIT) # re-use the last estimate `var` if the current estimate is not valid return all(isfinite.(var)) ? var : ad.exp_variance_draw.var end -# reuse the `var` slot in `WelfordVar` to store the estimated variance of the draw (the "current" one) +# reuse the `var` slot in the `exp_variance_draw` (which is `WelfordVar`) +# to store the estimated variance of the draw (the "current" / "foreground" one) function update!(ad::ExpWeightedWelfordVar) current_count(ad) >= ad.exp_variance_draw.n_min && (ad.exp_variance_draw.var .= get_estimation(ad)) end + ## Dense mass matrix adaptor abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end diff --git a/src/adaptation/nutpie_adaptor.jl b/src/adaptation/nutpie_adaptor.jl index 1b33d3c6..225bc558 100644 --- a/src/adaptation/nutpie_adaptor.jl +++ b/src/adaptation/nutpie_adaptor.jl @@ -1,5 +1,11 @@ -### Mutable states +####################################3 +### General methods +# it will then be forwarded to `adaptor` (there it resides in `exp_variance_draw.var`) +getM⁻¹(ca::AbstractHMCAdaptorWithGradients) = getM⁻¹(ca.pc) +getϵ(ca::AbstractHMCAdaptorWithGradients) = getϵ(ca.ssa) +finalize!(adaptor::AbstractHMCAdaptorWithGradients) = finalize!(adaptor.ssa) +### Mutable states mutable struct NutpieHMCAdaptorState i::Int n_adapts::Int @@ -36,7 +42,7 @@ end ### Nutpie's adaptation # Acknowledgement: ... -struct NutpieHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAdaptor +struct NutpieHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients pc::M ssa::Tssa early_window_share::Float64 @@ -59,57 +65,200 @@ function NutpieHMCAdaptor( return NutpieHMCAdaptor(pc, ssa, early_window_share, final_step_size_window_share, mass_matrix_switch_freq, early_mass_matrix_switch_freq, NutpieHMCAdaptorState()) end -# !Q: Is mass_matrix a variance or an inverse of it? It should be inverse, but accumulators are directly variance? 
-# forward the method to the current draw -# it will then be forwarded to `var` property -getM⁻¹(ca::NutpieHMCAdaptor) = getM⁻¹(ca.pc.exp_variance_draw) -getϵ(ca::NutpieHMCAdaptor) = getϵ(ca.ssa) - -function initialize!(adaptor::NutpieHMCAdaptor, n_adapts::Int, z::PhasePoint) +function initialize!(adaptor::NutpieHMCAdaptor, n_adapts::Int, ∇logπ::AbstractVecOrMat{<:AbstractFloat}) initialize!(adaptor.state, adaptor.early_window_share, adaptor.final_step_size_window_share, n_adapts) # !Q: Shall we initialize from the gradient? # Nutpie initializes the variance estimate with reciprocal of the gradient # Like: Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT)) - adaptor.pc.exp_variance_draw.var = (1 ./ abs.(z.ℓπ.gradient)) |> x -> clamp.(x, LOWER_LIMIT, UPPER_LIMIT) + # TODO: point to var more dynamically + adaptor.pc.exp_variance_draw.var = (1 ./ abs.(∇logπ)) |> x -> clamp.(x, LOWER_LIMIT, UPPER_LIMIT) return adaptor end -finalize!(adaptor::NutpieHMCAdaptor) = finalize!(adaptor.ssa) -# -is_in_first_step_size_window(ad::NutpieHMCAdaptor) = ad.state.i <= ad.state.final_step_size_window -is_in_early_window(ad) = ad.state.i <= ad.state.early_end -switch_freq(ad::NutpieHMCAdaptor) = is_in_early_window(ad) ? ad.early_mass_matrix_switch_freq : ad.mass_matrix_switch_freq +#################################### +## Special case: Skip the initiation of the mass matrix with gradient +struct NutpieHMCAdaptorNoGradInit{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + early_window_share::Float64 + final_step_size_window_share::Float64 + mass_matrix_switch_freq::Int + early_mass_matrix_switch_freq::Int + state::NutpieHMCAdaptorState +end +# Base.show(io::IO, a::NutpieHMCAdaptor) = +# print(io, "NutpieHMCAdaptor(\n pc=$(a.pc),\n ssa=$(a.ssa),\n init_buffer=$(a.init_buffer), term_buffer=$(a.term_buffer), window_size=$(a.window_size),\n state=$(a.state)\n)") + +function NutpieHMCAdaptorNoGradInit( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + early_window_share::Float64=0.3, + final_step_size_window_share::Float64=0.2, + mass_matrix_switch_freq::Int=60, + early_mass_matrix_switch_freq::Int=10 +) + return NutpieHMCAdaptorNoGradInit(pc, ssa, early_window_share, final_step_size_window_share, mass_matrix_switch_freq, early_mass_matrix_switch_freq, NutpieHMCAdaptorState()) +end +function initialize!(adaptor::NutpieHMCAdaptorNoGradInit, n_adapts::Int, ∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!(adaptor.state, adaptor.early_window_share, adaptor.final_step_size_window_share, n_adapts) + return adaptor +end +#################################### +## Special case: No switching, use StanHMCAdaptor-like strategy (but keep var+gradients) +struct NutpieHMCAdaptorNoSwitch{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + init_buffer::Int + term_buffer::Int + window_size::Int + state::StanHMCAdaptorState +end + +function NutpieHMCAdaptorNoSwitch( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + init_buffer::Int = 75, + term_buffer::Int = 50, + window_size::Int = 25, +) + return NutpieHMCAdaptorNoSwitch( + pc, + ssa, + init_buffer, + term_buffer, + window_size, + StanHMCAdaptorState(), + ) +end + +function initialize!(adaptor::NutpieHMCAdaptorNoSwitch, n_adapts::Int,∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!( + adaptor.state, + adaptor.init_buffer, + adaptor.term_buffer, + adaptor.window_size, + n_adapts, + ) + adaptor.pc.exp_variance_draw.var = (1 ./ abs.(∇logπ)) |> x -> 
clamp.(x, LOWER_LIMIT, UPPER_LIMIT) + return adaptor +end + +############################################ +## Special case: No switching, use StanHMCAdaptor-like strategy (but keep var+gradients) +## Both switching and grad init disabled +struct NutpieHMCAdaptorNoSwitchNoGradInit{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + init_buffer::Int + term_buffer::Int + window_size::Int + state::StanHMCAdaptorState +end + +function NutpieHMCAdaptorNoSwitchNoGradInit( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + init_buffer::Int = 75, + term_buffer::Int = 50, + window_size::Int = 25, +) + return NutpieHMCAdaptorNoSwitchNoGradInit( + pc, + ssa, + init_buffer, + term_buffer, + window_size, + StanHMCAdaptorState(), + ) +end + +function initialize!(adaptor::NutpieHMCAdaptorNoSwitchNoGradInit, n_adapts::Int,∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!( + adaptor.state, + adaptor.init_buffer, + adaptor.term_buffer, + adaptor.window_size, + n_adapts, + ) + return adaptor +end + + +##################################### +# Adaptation: main case ala Nutpie # # Changes vs Rust implementation # - step_size_adapt is at the top -# - several checks are handled in sampler (finalize adaption, do not adapt during normal sampling) +# - several checks are handled in sampler (finalize adaptation, does not adapt during normal sampling) +# - switch and push/update are handled separately to mimic the StanHMCAdaptor +# +# Missing: +# - collector checks on divergences or terminating at idx=0 // finite and None esimtaor +# - adapt! estimator only if collected stuff is good (divergences) +# - init for mass_matrix is grad.abs().recip.clamp(LOWER_LIMIT, UPPER_LIMIT) // init of ExpWindowDiagAdapt +# +is_in_first_step_size_window(tp::AbstractHMCAdaptorWithGradients) = tp.state.i <= tp.state.final_step_size_window +is_in_early_window(tp::AbstractHMCAdaptorWithGradients) = tp.state.i <= tp.state.early_end +switch_freq(tp::AbstractHMCAdaptorWithGradients) = is_in_early_window(tp) ? tp.early_mass_matrix_switch_freq : tp.mass_matrix_switch_freq +# function adapt!( - ad::NutpieHMCAdaptor, + tp::Union{NutpieHMCAdaptor,NutpieHMCAdaptorNoGradInit}, θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, - z::PhasePoint + ∇logπ::AbstractVecOrMat{<:AbstractFloat} ) - ad.state.i += 1 + tp.state.i += 1 - adapt!(ad.ssa, θ, α) + adapt!(tp.ssa, θ, α) # TODO: do we resize twice? also in update? # !Q: Why do we check resizing several times during iteration? (also in adapt!) - resize!(ad.pc, θ, z.ℓπ.gradient) # Resize pre-conditioner if necessary. + resize!(tp.pc, θ, ∇logπ) # Resize pre-conditioner if necessary. # determine whether to update mass matrix - if is_in_first_step_size_window(ad) + if is_in_first_step_size_window(tp) - # Switch swaps the background (_bg) values for currect, and resets the background values + # Switch swaps the background (_bg) values for current, and resets the background values # Frequency of the switch depends on the phase - background_count(ad.pc) >= switch_freq(ad) && switch!(ad.pc) + background_count(tp.pc) >= switch_freq(tp) && switch!(tp.pc) # TODO: implement a skipper for bad draws # !Q: Why does it always update? (as per Nuts-rs/Nutpie) - adapt!(ad.pc, θ, α, z.ℓπ.gradient, true) + adapt!(tp.pc, θ, α, ∇logπ, true) end end -# Missing: collector checks on divergences or terminating at idx=0 // finite and None esimtaor -# adapt! 
estimator only if collected stuff is good (divergences) -# init for mass_matrix is grad.abs().recip.clamp(LOWER_LIMIT, UPPER_LIMIT) // init of ExpWindowDiagAdapt -# add checks taht exp_vairances are the same length? + +##################################### +# Adaptation: No switching - ala StanHMCAdaptor +# +is_in_window(tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}) = + tp.state.i >= tp.state.window_start && tp.state.i <= tp.state.window_end +is_window_end(tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}) = tp.state.i in tp.state.window_splits +# +function adapt!( + tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}, + θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, + ∇logπ::AbstractVecOrMat{<:AbstractFloat} +) + tp.state.i += 1 + + adapt!(tp.ssa, θ, α) + + resize!(tp.pc, θ, ∇logπ) # Resize pre-conditioner if necessary. + + # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp + if is_in_window(tp) + # We accumlate stats from θ online and only trigger the update of M⁻¹ in the end of window. + is_update_M⁻¹ = is_window_end(tp) + adapt!(tp.pc, θ, α, ∇logπ, is_update_M⁻¹) + end + + if is_window_end(tp) + reset!(tp.ssa) + reset!(tp.pc) + end +end + + + diff --git a/src/sampler.jl b/src/sampler.jl index 2ed95755..b0bf1b43 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -75,9 +75,11 @@ function Adaptation.adapt!( adaptor::AbstractAdaptor, i::Int, n_adapts::Int, - θ::AbstractVecOrMat{<:AbstractFloat}, + z::PhasePoint, + # θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, ) + θ = z.θ isadapted = false if i <= n_adapts i == 1 && Adaptation.initialize!(adaptor, n_adapts) @@ -94,17 +96,19 @@ end function Adaptation.adapt!( h::Hamiltonian, κ::AbstractMCMCKernel, - adaptor::NutpieHMCAdaptor, + adaptor::AbstractHMCAdaptorWithGradients, i::Int, n_adapts::Int, - θ::AbstractVecOrMat{<:AbstractFloat}, + z::PhasePoint, + # θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, - z::PhasePoint ) + θ = z.θ + ∇logπ = z.ℓπ.gradient isadapted = false if i <= n_adapts - i == 1 && Adaptation.initialize!(adaptor, n_adapts, z) - adapt!(adaptor, θ, α, z) + i == 1 && Adaptation.initialize!(adaptor, n_adapts,∇logπ) + adapt!(adaptor, θ, α,∇logπ) i == n_adapts && finalize!(adaptor) h = update(h, adaptor) κ = update(κ, adaptor) @@ -208,7 +212,7 @@ function sample( # Adapt h and κ; what mutable is the adaptor tstat = stat(t) h, κ, isadapted = - adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate) + adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate) if isadapted num_divergent_transitions_during_adaption += tstat.numerical_error else From 3acd182457b0e2ffb91a6d7d22c16694d0e2ff29 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 13 Feb 2023 09:07:05 +0000 Subject: [PATCH 3/3] remove irrelevant files --- adapt_strategy.rs | 572 ---------------------------------------------- example.jl | 78 ------- nutpie.jl | 56 ----- rewrite.jl | 134 ----------- 4 files changed, 840 deletions(-) delete mode 100644 adapt_strategy.rs delete mode 100644 example.jl delete mode 100644 nutpie.jl delete mode 100644 rewrite.jl diff --git a/adapt_strategy.rs b/adapt_strategy.rs deleted file mode 100644 index 1d694e5f..00000000 --- a/adapt_strategy.rs +++ /dev/null @@ -1,572 +0,0 @@ -use std::{fmt::Debug, marker::PhantomData}; - -use itertools::izip; - -use crate::{ - 
cpu_potential::{CpuLogpFunc, EuclideanPotential}, - mass_matrix::{ - DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance, - }, - nuts::{ - AdaptStrategy, AsSampleStatVec, Collector, Hamiltonian, NutsOptions, SampleStatItem, - SampleStatValue, - }, - stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions}, -}; - -const LOWER_LIMIT: f64 = 1e-10f64; -const UPPER_LIMIT: f64 = 1e10f64; - -pub(crate) struct DualAverageStrategy { - step_size_adapt: DualAverage, - options: DualAverageSettings, - enabled: bool, - finalized: bool, - _phantom1: PhantomData, - _phantom2: PhantomData, -} - -impl DualAverageStrategy { - fn enable(&mut self) { - self.enabled = true; - } - - fn finalize(&mut self) { - self.finalized = true; - } -} - - -#[derive(Debug, Clone, Copy)] -pub struct DualAverageStats { - pub step_size_bar: f64, - pub mean_tree_accept: f64, - pub n_steps: u64, -} - -impl AsSampleStatVec for DualAverageStats { - fn add_to_vec(&self, vec: &mut Vec) { - vec.push(("step_size_bar", SampleStatValue::F64(self.step_size_bar))); - vec.push(( - "mean_tree_accept", - SampleStatValue::F64(self.mean_tree_accept), - )); - vec.push(("n_steps", SampleStatValue::U64(self.n_steps))); - } -} - -#[derive(Debug, Clone, Copy)] -pub struct DualAverageSettings { - pub target_accept: f64, - pub initial_step: f64, - pub params: DualAverageOptions, -} - -impl Default for DualAverageSettings { - fn default() -> Self { - Self { - target_accept: 0.8, - initial_step: 0.1, - params: DualAverageOptions::default(), - } - } -} - -impl AdaptStrategy for DualAverageStrategy { - type Potential = EuclideanPotential; - type Collector = AcceptanceRateCollector; - type Stats = DualAverageStats; - type Options = DualAverageSettings; - - fn new(options: Self::Options, _num_tune: u64, _dim: usize) -> Self { - Self { - options, - enabled: true, - step_size_adapt: DualAverage::new(options.params, options.initial_step), - finalized: false, - _phantom1: PhantomData::default(), - _phantom2: PhantomData::default(), - } - } - - fn init( - &mut self, - _options: &mut NutsOptions, - potential: &mut Self::Potential, - _state: &::State, - ) { - potential.step_size = self.options.initial_step; - } - - fn adapt( - &mut self, - _options: &mut NutsOptions, - potential: &mut Self::Potential, - _draw: u64, - collector: &Self::Collector, - ) { - if self.finalized { - self.step_size_adapt - .advance(collector.mean.current(), self.options.target_accept); - potential.step_size = self.step_size_adapt.current_step_size_adapted(); - return; - } - if !self.enabled { - return; - } - self.step_size_adapt - .advance(collector.mean.current(), self.options.target_accept); - potential.step_size = self.step_size_adapt.current_step_size() - } - - fn new_collector(&self) -> Self::Collector { - AcceptanceRateCollector::new() - } - - fn current_stats( - &self, - _options: &NutsOptions, - _potential: &Self::Potential, - collector: &Self::Collector, - ) -> Self::Stats { - DualAverageStats { - step_size_bar: self.step_size_adapt.current_step_size_adapted(), - mean_tree_accept: collector.mean.current(), - n_steps: collector.mean.count(), - } - } -} - -/// Settings for mass matrix adaptation -#[derive(Clone, Copy, Debug)] -pub struct DiagAdaptExpSettings { - pub store_mass_matrix: bool, -} - -impl Default for DiagAdaptExpSettings { - fn default() -> Self { - Self { - store_mass_matrix: false, - } - } -} - -pub(crate) struct ExpWindowDiagAdapt { - dim: usize, - exp_variance_draw: RunningVariance, - exp_variance_grad: RunningVariance, - exp_variance_grad_bg: 
RunningVariance, - exp_variance_draw_bg: RunningVariance, - settings: DiagAdaptExpSettings, - _phantom: PhantomData, -} - -impl ExpWindowDiagAdapt { - fn update_estimators(&mut self, collector: &DrawGradCollector) { - if collector.is_good { - self.exp_variance_draw - .add_sample(collector.draw.iter().copied()); - self.exp_variance_grad - .add_sample(collector.grad.iter().copied()); - self.exp_variance_draw_bg - .add_sample(collector.draw.iter().copied()); - self.exp_variance_grad_bg - .add_sample(collector.grad.iter().copied()); - } - } - - fn switch(&mut self, collector: &DrawGradCollector) { - self.exp_variance_draw = std::mem::replace( - &mut self.exp_variance_draw_bg, - RunningVariance::new(self.dim), - ); - self.exp_variance_grad = std::mem::replace( - &mut self.exp_variance_grad_bg, - RunningVariance::new(self.dim), - ); - - self.update_estimators(collector); - } - - fn current_count(&self) -> u64 { - assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count()); - self.exp_variance_draw.count() - } - - fn background_count(&self) -> u64 { - assert!(self.exp_variance_draw_bg.count() == self.exp_variance_grad_bg.count()); - self.exp_variance_draw_bg.count() - } - - fn update_potential(&self, potential: &mut EuclideanPotential) { - if self.current_count() < 3 { - return; - } - assert!(self.current_count() > 2); - potential.mass_matrix.update_diag( - izip!( - self.exp_variance_draw.current(), - self.exp_variance_grad.current(), - ) - .map(|(draw, grad)| { - let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT); - if !val.is_finite() { - None - } else { - Some(val) - } - }), - ); - } -} - - -#[derive(Clone, Debug)] -pub struct ExpWindowDiagAdaptStats { - pub mass_matrix_inv: Option>, -} - -impl AsSampleStatVec for ExpWindowDiagAdaptStats { - fn add_to_vec(&self, vec: &mut Vec) { - vec.push(( - "mass_matrix_inv", - SampleStatValue::OptionArray(self.mass_matrix_inv.clone()), - )); - } -} - -impl AdaptStrategy for ExpWindowDiagAdapt { - type Potential = EuclideanPotential; - type Collector = DrawGradCollector; - type Stats = ExpWindowDiagAdaptStats; - type Options = DiagAdaptExpSettings; - - fn new(options: Self::Options, _num_tune: u64, dim: usize) -> Self { - Self { - dim, - exp_variance_draw: RunningVariance::new(dim), - exp_variance_grad: RunningVariance::new(dim), - exp_variance_draw_bg: RunningVariance::new(dim), - exp_variance_grad_bg: RunningVariance::new(dim), - settings: options, - _phantom: PhantomData::default(), - } - } - - fn init( - &mut self, - _options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &::State, - ) { - self.exp_variance_draw.add_sample(state.q.iter().copied()); - self.exp_variance_draw_bg.add_sample(state.q.iter().copied()); - self.exp_variance_grad.add_sample(state.grad.iter().copied()); - self.exp_variance_grad_bg.add_sample(state.grad.iter().copied()); - - potential.mass_matrix.update_diag( - state.grad.iter().map(|&grad| { - Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT)) - }) - ); - - } - - fn adapt( - &mut self, - _options: &mut NutsOptions, - _potential: &mut Self::Potential, - _draw: u64, - _collector: &Self::Collector, - ) { - // Must be controlled from a different meta strategy - } - - fn new_collector(&self) -> Self::Collector { - DrawGradCollector::new(self.dim) - } - - fn current_stats( - &self, - _options: &NutsOptions, - potential: &Self::Potential, - _collector: &Self::Collector, - ) -> Self::Stats { - let diag = if self.settings.store_mass_matrix { - Some(potential.mass_matrix.variance.clone()) - 
} else { - None - }; - ExpWindowDiagAdaptStats { - mass_matrix_inv: diag, - } - } -} - - -pub(crate) struct GradDiagStrategy { - step_size: DualAverageStrategy, - mass_matrix: ExpWindowDiagAdapt, - options: GradDiagOptions, - num_tune: u64, - // The number of draws in the the early window - early_end: u64, - - // The first draw number for the final step size adaptation window - final_step_size_window: u64, -} - -#[derive(Debug, Clone, Copy)] -pub struct GradDiagOptions { - pub dual_average_options: DualAverageSettings, - pub mass_matrix_options: DiagAdaptExpSettings, - pub early_window: f64, - pub step_size_window: f64, - pub mass_matrix_switch_freq: u64, - pub early_mass_matrix_switch_freq: u64, -} - -impl Default for GradDiagOptions { - fn default() -> Self { - Self { - dual_average_options: DualAverageSettings::default(), - mass_matrix_options: DiagAdaptExpSettings::default(), - early_window: 0.3, - //step_size_window: 0.08, - //step_size_window: 0.15, - step_size_window: 0.2, - mass_matrix_switch_freq: 60, - early_mass_matrix_switch_freq: 10, - } - } -} - -impl AdaptStrategy for GradDiagStrategy { - type Potential = EuclideanPotential; - type Collector = CombinedCollector< - AcceptanceRateCollector< as Hamiltonian>::State>, - DrawGradCollector - >; - type Stats = CombinedStats; - type Options = GradDiagOptions; - - fn new(options: Self::Options, num_tune: u64, dim: usize) -> Self { - let num_tune_f = num_tune as f64; - let step_size_window = (options.step_size_window * num_tune_f) as u64; - let early_end = (options.early_window * num_tune_f) as u64; - let final_second_step_size = num_tune.saturating_sub(step_size_window); - - assert!(early_end < num_tune); - - Self { - step_size: DualAverageStrategy::new(options.dual_average_options, num_tune, dim), - mass_matrix: ExpWindowDiagAdapt::new(options.mass_matrix_options, num_tune, dim), - options, - num_tune, - early_end, - final_step_size_window: final_second_step_size, - } - } - - fn init( - &mut self, - options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &::State, - ) { - self.step_size.init(options, potential, state); - self.mass_matrix.init(options, potential, state); - self.step_size.enable(); - } - - fn adapt( - &mut self, - options: &mut NutsOptions, - potential: &mut Self::Potential, - draw: u64, - collector: &Self::Collector, - ) { - if draw >= self.num_tune { - return; - } - - if draw < self.final_step_size_window { - let is_early = draw < self.early_end; - let switch_freq = if is_early { - self.options.early_mass_matrix_switch_freq - } else { - self.options.mass_matrix_switch_freq - }; - - if self.mass_matrix.background_count() >= switch_freq { - self.mass_matrix.switch(&collector.collector2); - } else { - self.mass_matrix.update_estimators(&collector.collector2); - } - self.mass_matrix.update_potential(potential); - self.step_size.adapt(options, potential, draw, &collector.collector1); - return; - } - - if draw == self.num_tune - 1 { - self.step_size.finalize(); - } - self.step_size.adapt(options, potential, draw, &collector.collector1); - } - - fn new_collector(&self) -> Self::Collector { - CombinedCollector { - collector1: self.step_size.new_collector(), - collector2: self.mass_matrix.new_collector(), - } - } - - fn current_stats( - &self, - options: &NutsOptions, - potential: &Self::Potential, - collector: &Self::Collector, - ) -> Self::Stats { - CombinedStats { - stats1: self - .step_size - .current_stats(options, potential, &collector.collector1), - stats2: self - .mass_matrix - .current_stats(options, 
potential, &collector.collector2), - } - } -} - - -#[derive(Debug, Clone)] -pub struct CombinedStats { - pub stats1: D1, - pub stats2: D2, -} - -impl AsSampleStatVec for CombinedStats { - fn add_to_vec(&self, vec: &mut Vec) { - self.stats1.add_to_vec(vec); - self.stats2.add_to_vec(vec); - } -} - -pub(crate) struct CombinedCollector { - collector1: C1, - collector2: C2, -} - -impl Collector for CombinedCollector -where - C1: Collector, - C2: Collector, -{ - type State = C1::State; - - fn register_leapfrog( - &mut self, - start: &Self::State, - end: &Self::State, - divergence_info: Option<&dyn crate::nuts::DivergenceInfo>, - ) { - self.collector1 - .register_leapfrog(start, end, divergence_info); - self.collector2 - .register_leapfrog(start, end, divergence_info); - } - - fn register_draw(&mut self, state: &Self::State, info: &crate::nuts::SampleInfo) { - self.collector1.register_draw(state, info); - self.collector2.register_draw(state, info); - } - - fn register_init(&mut self, state: &Self::State, options: &crate::nuts::NutsOptions) { - self.collector1.register_init(state, options); - self.collector2.register_init(state, options); - } -} - -#[cfg(test)] -pub mod test_logps { - use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError}; - use thiserror::Error; - - #[derive(Clone)] - pub struct NormalLogp { - dim: usize, - mu: f64, - } - - impl NormalLogp { - pub(crate) fn new(dim: usize, mu: f64) -> NormalLogp { - NormalLogp { dim, mu } - } - } - - #[derive(Error, Debug)] - pub enum NormalLogpError {} - impl LogpError for NormalLogpError { - fn is_recoverable(&self) -> bool { - false - } - } - - impl CpuLogpFunc for NormalLogp { - type Err = NormalLogpError; - - fn dim(&self) -> usize { - self.dim - } - fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { - let n = position.len(); - assert!(gradient.len() == n); - - let mut logp = 0f64; - for (p, g) in position.iter().zip(gradient.iter_mut()) { - let val = *p - self.mu; - logp -= val * val / 2.; - *g = -val; - } - Ok(logp) - } - } -} - -#[cfg(test)] -mod test { - use super::test_logps::NormalLogp; - use super::*; - use crate::nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions}; - - #[test] - fn instanciate_adaptive_sampler() { - let ndim = 10; - let func = NormalLogp::new(ndim, 3.); - let num_tune = 100; - let options = GradDiagOptions::default(); - let strategy = GradDiagStrategy::new(options, num_tune, ndim); - - let mass_matrix = DiagMassMatrix::new(ndim); - let max_energy_error = 1000f64; - let step_size = 0.1f64; - - let potential = EuclideanPotential::new(func, mass_matrix, max_energy_error, step_size); - let options = NutsOptions { - maxdepth: 10u64, - store_gradient: true, - }; - - let rng = { - use rand::SeedableRng; - rand::rngs::StdRng::seed_from_u64(42) - }; - let chain = 0u64; - - let mut sampler = NutsChain::new(potential, strategy, options, rng, chain); - sampler.set_position(&vec![1.5f64; ndim]).unwrap(); - for _ in 0..200 { - sampler.draw().unwrap(); - } - } -} \ No newline at end of file diff --git a/example.jl b/example.jl deleted file mode 100644 index f0679a2a..00000000 --- a/example.jl +++ /dev/null @@ -1,78 +0,0 @@ -using AdvancedHMC, ForwardDiff -using LogDensityProblems -using LinearAlgebra -using Distributions - -# Define the target distribution using the `LogDensityProblem` interface -struct LogTargetDensity - dim::Int -end -LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal -LogDensityProblems.dimension(p::LogTargetDensity) = p.dim 
-LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}() - -# Choose parameter dimensionality and initial parameter value -D = 10; -initial_θ = rand(D); -ℓπ = LogTargetDensity(D) - -# Set the number of samples to draw and warmup iterations -n_samples, n_adapts = 2_000, 1_000 - -# Define a Hamiltonian system -metric = DiagEuclideanMetric(D) -hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff) - -# Define a leapfrog solver, with initial step size chosen heuristically -initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) -integrator = Leapfrog(initial_ϵ) - -# Define an HMC sampler, with the following components -# - multinomial sampling scheme, -# - generalised No-U-Turn criteria, and -# - windowed adaption for step-size and diagonal mass matrix -proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) -adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) - -# Run the sampler to draw samples from the specified Gaussian, where -# - `samples` will store the samples -# - `stats` will store diagnostic statistics for each sample -samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true) - - -# Other Example -# Setup -using Distributions: Distributions -using Bijectors: Bijectors -using Random -struct LogDensityDistribution{D<:Distributions.Distribution} - dist::D -end - -LogDensityProblems.dimension(d::LogDensityDistribution) = length(d.dist) -function LogDensityProblems.logdensity(ld::LogDensityDistribution, y) - d = ld.dist - b = Bijectors.inverse(Bijectors.bijector(d)) - x, logjac = Bijectors.with_logabsdet_jacobian(b, y) - return logpdf(d, x) + logjac -end -LogDensityProblems.capabilities(::Type{<:LogDensityDistribution}) = LogDensityProblems.LogDensityOrder{0}() - -# Random variance -n_samples, n_adapts = 2_000, 1_000 -Random.seed!(1) -D = 10 -σ² = 1 .+ abs.(randn(D)) - -# Diagonal Gaussian -ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²))) -metric = DiagEuclideanMetric(D) -θ_init = rand(D) -h = Hamiltonian(metric, ℓπ, ForwardDiff) -κ = NUTS(Leapfrog(find_good_stepsize(h, θ_init))) -adaptor = StanHMCAdaptor( - MassMatrixAdaptor(metric), - StepSizeAdaptor(0.8, κ.τ.integrator) -) -samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose=false) -# adaptor.pc.var ≈ σ² diff --git a/nutpie.jl b/nutpie.jl deleted file mode 100644 index de5afa90..00000000 --- a/nutpie.jl +++ /dev/null @@ -1,56 +0,0 @@ -# Example for Nuts-rs / Nutpie Adaptor -using AdvancedHMC, ForwardDiff -using LogDensityProblems -using LinearAlgebra -using Distributions -using Plots -const A = AdvancedHMC - -# Define the target distribution using the `LogDensityProblem` interface -struct LogTargetDensity - dim::Int -end -LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal -LogDensityProblems.dimension(p::LogTargetDensity) = p.dim -LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}() - -# Choose parameter dimensionality and initial parameter value -D = 20; -initial_θ = rand(D); -ℓπ = LogTargetDensity(D) -n_samples, n_adapts = 2_000, 200 - -# DEFAULT -metric = DiagEuclideanMetric(D) -hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff) -initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) -integrator = Leapfrog(initial_ϵ) -proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) -adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) - -@time samples1, 
stats1 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true); - -# NUTPIE -# https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs -metric = DiagEuclideanMetric(D) -hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff) -initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) -integrator = Leapfrog(initial_ϵ) -proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator) -pc = A.ExpWeightedWelfordVar(size(metric)) -adaptor = A.NutpieHMCAdaptor(pc, StepSizeAdaptor(0.8, integrator)) -@time samples2, stats2 = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true); - -# # Plots -# Plot of variance -get_component(samples, idx) = [samples[i][idx] for i in 1:length(samples)] -# get_component(samples, 3) - -# comparison -idx = 10 -plot(plot(get_component(samples1, idx), label="Default", color=palette(:default)[1]), - plot(get_component(samples2, idx), label="Nutpie/Nuts-rs", color=palette(:default)[2]), plot_title=title = "Comparison of component $idx", layout=(2, 1)) - -# Histogram -pl = histogram(get_component(samples1, idx), label="Default", fillstyle=:\, color=palette(:default)[1], alpha=0.5) -histogram!(pl, get_component(samples2, idx), label="Nutpie/Nuts-rs", alpha=0.5, title="Comparison of component $idx") \ No newline at end of file diff --git a/rewrite.jl b/rewrite.jl deleted file mode 100644 index c330449a..00000000 --- a/rewrite.jl +++ /dev/null @@ -1,134 +0,0 @@ -using StanHMCAdaptor - -mutable struct DiagAdaptExpSettings <: StanHMCAdaptorSettings - store_mass_matrix::Bool -end - -DiagAdaptExpSettings() = DiagAdaptExpSettings(false) - -mutable struct ExpWindowDiagAdapt{F} <: StanHMCAdaptor{F} - dim::Int - exp_variance_draw::RunningVariance - exp_variance_grad::RunningVariance - exp_variance_grad_bg::RunningVariance - exp_variance_draw_bg::RunningVariance - settings::DiagAdaptExpSettings - _phantom::Phantom{F} -end - -function ExpWindowDiagAdapt(dim::Int, settings::DiagAdaptExpSettings) - ExpWindowDiagAdapt(dim, RunningVariance(dim), RunningVariance(dim), RunningVariance(dim), RunningVariance(dim), settings, Phantom{F}()) -end - -function update!(adaptor::ExpWindowDiagAdapt, state::StanHMCAdaptorState, collector::DrawGradCollector) - if collector.is_good - for i in 1:adaptor.dim - adaptor.exp_variance_draw.add_sample(collector.draw[i]) - adaptor.exp_variance_grad.add_sample(collector.grad[i]) - adaptor.exp_variance_draw_bg.add_sample(collector.draw[i]) - adaptor.exp_variance_grad_bg.add_sample(collector.grad[i]) - end - end - if adaptor.exp_variance_draw.count() >= 3 - for i in 1:adaptor.dim - diag = (adaptor.exp_variance_draw.current()[i] / adaptor.exp_variance_grad.current()[i])^0.5 - diag = max(LOWER_LIMIT, min(UPPER_LIMIT, diag)) - if isfinite(diag) - state.mass_matrix.update_diag[i] = diag - end - end - end -end - -function initialize_state(adaptor::ExpWindowDiagAdapt) - return StanHMCAdaptorState() -end - -mutable struct ExpWindowDiagAdaptState <: StanHMCAdaptorState - mass_matrix_inv::Union{Nothing,Vector{Float64}} -end - -function create_adaptor_state(adaptor::ExpWindowDiagAdapt) - return ExpWindowDiagAdaptState(nothing) -end - -function sample_stats(adaptor::ExpWindowDiagAdapt, state::ExpWindowDiagAdaptState) - ExpWindowDiagAdaptState(state.mass_matrix_inv) -end - -# Grad - -mutable struct GradDiagStrategy{F} <: StanHMCAdaptor{F} - step_size::DualAverageStrategy{F,DiagMassMatrix} - mass_matrix::ExpWindowDiagAdapt{F} - options::GradDiagOptions - num_tune::UInt64 - early_end::UInt64 - 
final_step_size_window::UInt64 -end - -mutable struct GradDiagOptions - dual_average_options::DualAverageSettings - mass_matrix_options::DiagAdaptExpSettings - early_window::Float64 - step_size_window::Float64 - mass_matrix_switch_freq::UInt64 - early_mass_matrix_switch_freq::UInt64 -end - -GradDiagOptions() = GradDiagOptions(DualAverageSettings(), DiagAdaptExpSettings(), 0.3, 0.2, 60, 10) - -mutable struct GradDiagStats <: StanHMCAdaptorStats - step_size_stats::DualAverageStats - mass_matrix_stats::ExpWindowDiagAdaptStats -end - -function GradDiagStrategy(options::GradDiagOptions, num_tune::UInt64, dim::Int) - num_tune_f = convert(Float64, num_tune) - step_size_window = convert(UInt64, options.step_size_window - * - num_tune_f) - early_end = convert(UInt64, options.early_window * num_tune_f) - final_second_step_size = max(num_tune - convert(UInt64, step_size_window), 0) - - GradDiagStrategy(DualAverageStrategy(options.dual_average_options, num_tune, dim), - ExpWindowDiagAdapt(dim, options.mass_matrix_options), - options, - num_tune, - early_end, - final_second_step_size) -end - -function update!(adaptor::GradDiagStrategy, state::StanHMCAdaptorState, collector::DrawGradCollector) - if collector.is_good - step_size_stats = update!(adaptor.step_size, state, collector) - mass_matrix_stats = update!(adaptor.mass_matrix, state, collector) - end - if adaptor.draw >= adaptor.num_tune - return - end - if adaptor.draw < adaptor.final_step_size_window - is_early = adaptor.draw < adaptor.early_end - switch_freq = is_early ? adaptor.options.early_mass_matrix_switch_freq : adaptor.options.mass_matrix_switch_freq - if adaptor.mass_matrix.background_count() >= switch_freq - adaptor.mass_matrix.switch(collector) - end - end - if adaptor.draw >= adaptor.final_step_size_window - adaptor.mass_matrix.update_potential(potential) - end -end - -function initialize_state(adaptor::GradDiagStrategy) - return initialize_state(adaptor.mass_matrix) -end - -function create_adaptor_state(adaptor::GradDiagStrategy) - return create_adaptor_state(adaptor.mass_matrix) -end - -function sample_stats(adaptor::GradDiagStrategy, state::StanHMCAdaptorState) - step_size_stats = sample_stats(adaptor.step_size, state) - mass_matrix_stats = sample_stats(adaptor.mass_matrix, state) - GradDiagStats(step_size_stats, mass_matrix_stats) -end \ No newline at end of file
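

---

Reviewer note (not part of the patches above): the fg/bg "switching" variance
estimator that PATCH 2 wires into adapt! can be illustrated standalone. The
Julia sketch below is illustrative only -- WelfordVar, SwitchingVar,
add_sample!, switch! and mass_diag are made-up names, not the AdvancedHMC or
nuts-rs API. Two Welford accumulators are kept per quantity (draws and
gradients); the background one is promoted to foreground every switch_freq
draws so that old samples age out, and the diagonal inverse mass matrix is
estimated as sqrt(var(draw) / var(grad)), clamped to [1e-10, 1e10] like the
Rust LOWER_LIMIT/UPPER_LIMIT constants.

    # Minimal sketch of Nutpie-style fg/bg variance switching (hypothetical names).
    mutable struct WelfordVar
        n::Int
        mean::Vector{Float64}
        m2::Vector{Float64}
    end
    WelfordVar(dim::Int) = WelfordVar(0, zeros(dim), zeros(dim))

    function add_sample!(w::WelfordVar, x::AbstractVector)
        w.n += 1
        delta = x .- w.mean
        w.mean .+= delta ./ w.n
        w.m2 .+= delta .* (x .- w.mean)
    end
    variance(w::WelfordVar) = w.m2 ./ max(w.n - 1, 1)

    mutable struct SwitchingVar
        fg::WelfordVar  # drives the current mass-matrix estimate
        bg::WelfordVar  # accumulates in the background; promoted by switch!
    end
    SwitchingVar(dim::Int) = SwitchingVar(WelfordVar(dim), WelfordVar(dim))

    function add_sample!(s::SwitchingVar, x::AbstractVector)
        add_sample!(s.fg, x)
        add_sample!(s.bg, x)
    end

    # Swap the background estimator in for the foreground one and restart the
    # background from scratch, mirroring switch!/ExpWindowDiagAdapt::switch.
    function switch!(s::SwitchingVar)
        s.fg = s.bg
        s.bg = WelfordVar(length(s.fg.mean))
    end

    # Diagonal inverse mass matrix from draw/grad variances (cf. update_potential).
    mass_diag(vdraw, vgrad) = clamp.(sqrt.(vdraw ./ vgrad), 1e-10, 1e10)

A driver loop would call add_sample! on every accepted draw (once for θ, once
for ∇logπ), call switch! whenever the background count reaches the
phase-dependent switch frequency (10 in the early window, 60 afterwards in
GradDiagOptions), and refresh the metric via mass_diag only once at least
three samples have accumulated, matching the current_count() < 3 guard.

For the no-switch variants, is_window_end consults precomputed window splits.
Below is a rough sketch of the Stan-style doubling schedule implied by
init_buffer/term_buffer/window_size (75/50/25 by default); the actual
StanHMCAdaptorState bookkeeping may differ in edge cases.

    # Hypothetical helper: end-of-window indices with doubling windows.
    function window_splits(init_buffer::Int, term_buffer::Int, window_size::Int, n_adapts::Int)
        splits = Int[]
        i, w = init_buffer, window_size
        last = n_adapts - term_buffer
        while i + w <= last
            # extend the final window to the end of the adaptation phase
            if i + w + 2w > last
                push!(splits, last)
                break
            end
            push!(splits, i + w)
            i += w
            w *= 2
        end
        return splits
    end

With the defaults and n_adapts = 1000 this yields [100, 150, 250, 450, 950]:
M⁻¹ is refreshed at the end of each doubling window, and the step-size and
pre-conditioner resets fire at the same points.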