diff --git a/Project.toml b/Project.toml index 253736676..c6d662c44 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" - -version = "0.24.4" +version = "0.24.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index fe2c0b3e5..bd7e8d8fb 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -363,7 +363,14 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This method is considered legacy, and is likely to be deprecated in the future. """ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) + T = Base.promote_op(getindex, typeof(vi), typeof(spl)) + if T === Union{} + # In this case `getindex(vi, spl)` errors + # Let us throw a more descriptive error message + # Ref https://github.com/TuringLang/Turing.jl/issues/2151 + return eltype(vi[spl]) + end + return eltype(T) end # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert diff --git a/src/varinfo.jl b/src/varinfo.jl index 83c914844..24316aed7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1444,7 +1444,7 @@ function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple ranges = _getranges(vi, spl) # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return vcat(_getindex(vi.metadata, ranges)...) + return reduce(vcat, _getindex(vi.metadata, ranges)) end # Recursively builds a tuple of the `vals` of all the symbols @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 71c7c2263..285f5b0c9 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -2,10 +2,12 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DynamicPPL = "0.24" +ReverseDiff = "1.15" Turing = "0.30" julia = "1.7" diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl index 2c1d5085d..faadd1257 100644 --- a/test/turing/runtests.jl +++ b/test/turing/runtests.jl @@ -1,6 +1,7 @@ using DynamicPPL using Turing using LinearAlgebra +using ReverseDiff using Random using Test diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 93195465a..c4e3fa87b 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -311,4 +311,35 @@ @test vi.metadata.w.gids[1] == Set([hmc.selector]) @test vi.metadata.u.gids[1] == Set([hmc.selector]) =# end + + @testset "Turing#2151: eltype(vi, spl)" begin + # build data + t = 1:0.05:8 + σ = 0.3 + y = @. rand(sin(t) + Normal(0, σ)) + + @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # Priors + α ~ Normal(y[1], 0.001) + τ ~ Exponential(1) + η ~ filldist(Normal(0, 1), TT - 1) + σ ~ Exponential(1) + + # create latent variable + x = Vector{T}(undef, TT) + x[1] = α + for t in 2:TT + x[t] = x[t - 1] + η[t - 1] * τ + end + + # measurement model + y ~ MvNormal(x, σ^2 * I) + + return x + end + + n = 10 + model = state_space(y, length(t)) + @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n + end end