Skip to content

Commit

Permalink
Warn user if we're struggling to find good init params (#1999)
Browse files Browse the repository at this point in the history
* added warning message in case we cant find good initialization point in a reasonable number of tries

* version bump

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* added test for HMC initial params warning

* Update test/inference/hmc.jl

Co-authored-by: David Widmann <[email protected]>

* fixed the warning test

* Update test/inference/hmc.jl

Co-authored-by: David Widmann <[email protected]>

* relax prior tests a bit

* further relaxation

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion authored Jun 13, 2023
1 parent 861ae37 commit 9f76d75
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.25.2"


[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Expand Down
7 changes: 7 additions & 0 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,20 @@ function DynamicPPL.initialstep(
# If no initial parameters are provided, resample until the log probability
# and its gradient are finite.
if init_params === nothing
init_attempt_count = 1
while !isfinite(z)
if init_attempt_count == 10
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `init_params` keyword"
end

# NOTE: This will sample in the unconstrained space.
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
theta = vi[spl]

hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
z = AHMC.phasepoint(rng, theta, hamiltonian)

init_attempt_count += 1
end
end

Expand Down
19 changes: 18 additions & 1 deletion test/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,23 @@
alg = NUTS(1000, 0.8)
gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
chain = sample(gdemo_default_prior, alg, 10_000)
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.45)
end

@turing_testset "warning for difficult init params" begin
attempt = 0
@model function demo_warn_init_params()
x ~ Normal()
if (attempt += 1) < 30
Turing.@addlogprob! -Inf
end
end

@test_logs (
:warn,
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `init_params` keyword",
) (:info,) match_mode=:any begin
sample(demo_warn_init_params(), NUTS(), 5)
end
end
end

0 comments on commit 9f76d75

Please sign in to comment.