diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index 803458c6..2bb8e8d3 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl") include("relativistic_hmc.jl") include("riemannian_hmc.jl") -@main function runtests(patterns...; dry::Bool = false) +Comonicon.@main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) end diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 1472a622..3a2ff638 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -35,8 +35,12 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -function AbstractMCMC.setparams!!(state::HMCState, θ) - return @set state.transition.z.θ = θ +function AbstractMCMC.setparams!!(state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) end """