From ad848b2b4692f60ae2c0dd462625ec7eb1f6b303 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Jul 2024 18:18:13 +0200 Subject: [PATCH] Stricter types for evaluate!! methods (#629) --- src/model.jl | 9 +++++++-- test/model.jl | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 748f88ad5..a7c48017a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -909,13 +909,18 @@ function AbstractPPL.evaluate!!(model::Model, context::AbstractContext) return evaluate!!(model, VarInfo(), context) end -function AbstractPPL.evaluate!!(model::Model, args...) +function AbstractPPL.evaluate!!( + model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +) return evaluate!!(model, Random.default_rng(), args...) end # without VarInfo function AbstractPPL.evaluate!!( - model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args... + model::Model, + rng::Random.AbstractRNG, + sampler::AbstractSampler, + args::AbstractContext..., ) return evaluate!!(model, rng, VarInfo(), sampler, args...) end diff --git a/test/model.jl b/test/model.jl index f98dc03ab..c8fdf0202 100644 --- a/test/model.jl +++ b/test/model.jl @@ -396,4 +396,18 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end end + + @testset "Erroneous model call" begin + # Calling a model with the wrong arguments used to lead to infinite recursion, see + # https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it. + @model function a_model(x) + m ~ Normal(0, 1) + x ~ Normal(m, 1) + return nothing + end + instance = a_model(1.0) + # `instance` should be called with rng, context, etc., but one may easily get + # confused and call it the way you are meant to call `a_model`. + @test_throws MethodError instance(1.0) + end end