From 65371ce4d5a241f123d00e413f02080b387370b4 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 5 Apr 2019 12:05:37 +0100 Subject: [PATCH] Fix a bug in `find_good_eps` (#744) * revert theta after find_good_eps * revert lj as well * change random seed * add random seed --- src/inference/support/hmc_core.jl | 4 +++- test/inference/hmcda.jl | 2 ++ test/utilities/io.jl | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/inference/support/hmc_core.jl b/src/inference/support/hmc_core.jl index 1e643b16d..9cc22fddc 100644 --- a/src/inference/support/hmc_core.jl +++ b/src/inference/support/hmc_core.jl @@ -232,8 +232,10 @@ function find_good_eps(model, spl::Sampler{T}, vi::VarInfo) where T momentum_sampler = gen_momentum_sampler(vi, spl) grad_func = gen_grad_func(vi, spl, model) H_func = gen_H_func() - θ = vi[spl] + θ, lj = vi[spl], vi.logp ϵ = _find_good_eps(θ, lj_func, grad_func, H_func, momentum_sampler) + vi[spl] = θ + setlogp!(vi, lj) @info "[Turing] found initial ϵ: $ϵ" return ϵ end diff --git a/test/inference/hmcda.jl b/test/inference/hmcda.jl index ad10460b1..4ec46a29a 100644 --- a/test/inference/hmcda.jl +++ b/test/inference/hmcda.jl @@ -4,6 +4,8 @@ using Turing: Sampler include("../test_utils/AllUtils.jl") @testset "hmcda.jl" begin + Random.seed!(1234) + @numerical_testset "hmcda inference" begin alg1 = HMCDA(3000, 1000, 0.65, 0.015) # alg2 = Gibbs(3000, HMCDA(1, 200, 0.65, 0.35, :m), HMC(1, 0.25, 3, :s)) diff --git a/test/utilities/io.jl b/test/utilities/io.jl index 076761bdd..40d8db930 100644 --- a/test/utilities/io.jl +++ b/test/utilities/io.jl @@ -4,7 +4,7 @@ include("../test_utils/AllUtils.jl") @testset "io.jl" begin @testset "chain save/resume" begin - Random.seed!(123) + Random.seed!(1234) alg1 = HMCDA(3000, 1000, 0.65, 0.15) alg2 = PG(20, 500)