From 5b24cebe773922e0f3d5c4cb7f53162eb758b04d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 7 Nov 2024 14:52:55 +0000 Subject: [PATCH] Don't get stuck in an infinite loop if HMC can't find an initial point (#2392) * Error after 1000 attempts at finding initial parameters * Add a test * Fix missing import * Bump Project.toml --- src/mcmc/hmc.jl | 5 +++++ test/mcmc/hmc.jl | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d01ef274a..5887feb5e 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -181,6 +181,11 @@ function DynamicPPL.initialstep( if init_attempt_count == 10 @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword" end + if init_attempt_count == 1000 + error( + "failed to find valid initial parameters in $(init_attempt_count) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues", + ) + end # NOTE: This will sample in the unconstrained space. vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform())) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 27c055394..27c928896 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -15,7 +15,7 @@ import Random using StableRNGs: StableRNG using StatsFuns: logistic import Mooncake -using Test: @test, @test_logs, @testset +using Test: @test, @test_logs, @testset, @test_throws using Turing @testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends @@ -272,6 +272,15 @@ using Turing end end + @testset "error for impossible model" begin + @model function demo_impossible() + x ~ Normal() + Turing.@addlogprob! -Inf + end + + @test_throws ErrorException sample(demo_impossible(), NUTS(; adtype=adbackend), 5) + end + @testset "(partially) issue: #2095" begin @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2)