From 6f44f01a0e0c09932f4026fb03d03691cb82c3c1 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 15 Apr 2020 21:21:16 +0100 Subject: [PATCH] Improved particle filter error message. (#900) --- src/core/container.jl | 4 +++- test/inference/smc.jl | 30 ++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 test/inference/smc.jl diff --git a/src/core/container.jl b/src/core/container.jl index dd53de973..fbd6bfc21 100644 --- a/src/core/container.jl +++ b/src/core/container.jl @@ -133,7 +133,9 @@ function Libtask.consume(pc :: ParticleContainer) if num_done == n res = Val{:done} elseif num_done != 0 - error("[consume]: mis-aligned execution traces, num_particles= $(n), num_done=$(num_done).") + # The posterior for models with random number of observations is not well-defined. + error("[consume]: mis-aligned execution traces, num_particles= $(n), + num_done=$(num_done). Please make sure the number of observations is NOT random.") else # update incremental likelihoods z2 = logZ(pc) diff --git a/test/inference/smc.jl b/test/inference/smc.jl new file mode 100644 index 000000000..00810223f --- /dev/null +++ b/test/inference/smc.jl @@ -0,0 +1,30 @@ +using Turing, Random, Test +using StatsFuns + +dir = splitdir(splitdir(pathof(Turing))[1])[1] +include(dir*"/test/test_utils/AllUtils.jl") + +@turing_testset "smc.jl" begin + @model normal() = begin + a ~ Normal(4,5) + 3 ~ Normal(a,2) + b ~ Normal(a,1) + 1.5 ~ Normal(b,2) + a, b + end + + tested = sample(normal(), SMC(), 100); + + # failing test + @model fail_smc() = begin + a ~ Normal(4,5) + 3 ~ Normal(a,2) + b ~ Normal(a,1) + if a >= 4.0 + 1.5 ~ Normal(b,2) + end + a, b + end + + @test_throws ErrorException sample(fail_smc(), SMC(), 100) +end diff --git a/test/runtests.jl b/test/runtests.jl index 8ae240548..e5c5438b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,6 +30,7 @@ include("test_utils/AllUtils.jl") include("inference/is.jl") include("inference/mh.jl") include("inference/ess.jl") + include("inference/smc.jl") include("inference/AdvancedSMC.jl") include("inference/Inference.jl") include("contrib/inference/dynamichmc.jl")