From 4b0be7fe5bbb3fd91f6bc342206b9adcc59f6bc1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 22 Oct 2024 19:44:17 +0100 Subject: [PATCH] update `setparams!!` --- research/tests/runtests.jl | 2 +- src/abstractmcmc.jl | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index 803458c6..2bb8e8d3 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl") include("relativistic_hmc.jl") include("riemannian_hmc.jl") -@main function runtests(patterns...; dry::Bool = false) +Comonicon.@main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) end diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 1472a622..3a2ff638 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -35,8 +35,12 @@ function AbstractMCMC.getparams(state::HMCState) return state.transition.z.θ end -function AbstractMCMC.setparams!!(state::HMCState, θ) - return @set state.transition.z.θ = θ +function AbstractMCMC.setparams!!(state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) end """