diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 9403bffe..0e2cdb63 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -165,8 +165,8 @@ function AbstractMCMC.step( # Compute next transition and state. state = HMCState(0, t, metric, κ, adaptor) - # Take actual first step. - return AbstractMCMC.step(rng, model, spl, state; kwargs...) + # Return the initial transition and state. + return Transition(t.z, merge(stat(t), (is_adapt = false,))), state end function AbstractMCMC.step( @@ -260,10 +260,13 @@ function (cb::HMCProgressCallback)( κ = state.κ tstat = t.stat isadapted = tstat.is_adapt - if isadapted - cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error - else - cb.num_divergent_transitions[] += tstat.numerical_error + # The initial transition will not much information beyond the `is_adapt` field. + if haskey(tstat, :numerical_error) + if isadapted + cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error + else + cb.num_divergent_transitions[] += tstat.numerical_error + end end # Update progress meter