diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 8e9bffd2c..0701c89fc 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -174,7 +174,7 @@ function AbstractMCMC.sample_init!( model(rng, spl.state.vi, spl) elseif islinked(spl.state.vi, spl) && spl.selector.tag != :default invlink!(spl.state.vi, spl) - model(rng, spl.state.vi, spl) + model(rng, spl.state.vi, spl) end end @@ -289,7 +289,7 @@ No-U-Turn Sampler (NUTS) sampler. Usage: ```julia -NUTS() # Use default NUTS configuration. +NUTS() # Use default NUTS configuration. NUTS(1000, 0.65) # Use 1000 adaption steps, and target accept ratio 0.65. ``` @@ -299,7 +299,7 @@ Arguments: - `δ::Float64` : Target acceptance rate for dual averaging. - `max_depth::Int` : Maximum doubling tree depth. - `Δ_max::Float64` : Maximum divergence during doubling tree. -- `ϵ::Float64` : Inital step size; 0 means automatically searching using a heuristic procedure. +- `init_ϵ::Float64` : Inital step size; 0 means automatically searching using a heuristic procedure. """ mutable struct NUTS{AD, space, metricT <: AHMC.AbstractMetric} <: AdaptiveHamiltonian{AD} @@ -434,8 +434,8 @@ function AbstractMCMC.step!( # Adaptation if spl.alg isa AdaptiveHamiltonian - spl.state.h, spl.state.traj, isadapted = - AHMC.adapt!(spl.state.h, spl.state.traj, spl.state.adaptor, + spl.state.h, spl.state.traj, isadapted = + AHMC.adapt!(spl.state.h, spl.state.traj, spl.state.adaptor, spl.state.i, spl.alg.n_adapts, t.z.θ, t.stat.acceptance_rate) end