Skip to content

Commit

Permalink
Fix a bug in find_good_eps (#744)
Browse files Browse the repository at this point in the history
* revert theta after find_good_eps

* revert lj as well

* change random seed

* add random seed
  • Loading branch information
xukai92 authored and yebai committed Apr 5, 2019
1 parent 3f72f12 commit 65371ce
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/inference/support/hmc_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,10 @@ function find_good_eps(model, spl::Sampler{T}, vi::VarInfo) where T
momentum_sampler = gen_momentum_sampler(vi, spl)
grad_func = gen_grad_func(vi, spl, model)
H_func = gen_H_func()
θ = vi[spl]
θ, lj = vi[spl], vi.logp
ϵ = _find_good_eps(θ, lj_func, grad_func, H_func, momentum_sampler)
vi[spl] = θ
setlogp!(vi, lj)
@info "[Turing] found initial ϵ: "
return ϵ
end
Expand Down
2 changes: 2 additions & 0 deletions test/inference/hmcda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using Turing: Sampler
include("../test_utils/AllUtils.jl")

@testset "hmcda.jl" begin
Random.seed!(1234)

@numerical_testset "hmcda inference" begin
alg1 = HMCDA(3000, 1000, 0.65, 0.015)
# alg2 = Gibbs(3000, HMCDA(1, 200, 0.65, 0.35, :m), HMC(1, 0.25, 3, :s))
Expand Down
2 changes: 1 addition & 1 deletion test/utilities/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include("../test_utils/AllUtils.jl")

@testset "io.jl" begin
@testset "chain save/resume" begin
Random.seed!(123)
Random.seed!(1234)

alg1 = HMCDA(3000, 1000, 0.65, 0.15)
alg2 = PG(20, 500)
Expand Down

0 comments on commit 65371ce

Please sign in to comment.