From d6a00c822c2f3392b316df13d3c596037dc6168d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Oct 2023 17:03:52 +0100 Subject: [PATCH 1/3] concretize reshape in `reconstruct` for `LKJCholesky` to avoid type-instabilities --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 0135e4c24..9ba66e63e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -238,7 +238,7 @@ reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) - return reconstruct(dist, reshape(val, size(dist))) + return reconstruct(dist, Matrix(reshape(val, size(dist)))) end function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real}) return Cholesky(val, dist.uplo, 0) From 6b0c1e63d424a8fed5c5ad019728cacf8eca03c2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Oct 2023 17:46:56 +0100 Subject: [PATCH 2/3] added tests for `demo_lkjchol` --- test/model.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/model.jl b/test/model.jl index fa7f5de47..566a292fd 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,6 +25,10 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end +is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false +is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true +is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true + @testset "model.jl" begin @testset "convenience functions" begin model = gdemo_default # defined in test/test_util.jl @@ -329,4 +333,32 @@ end @test x_true.UL == result.x.UL end end + + @testset "Type stability of models" begin + models_to_test = [ + # FIXME: Fix issues with type-stability in `DEMO_MODELS`. + # DynamicPPL.TestUtils.DEMO_MODELS..., + DynamicPPL.TestUtils.demo_lkjchol(2), + ] + @testset "$(model.f)" for model in models_to_test + vns = DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand(model) + varinfos = filter( + is_typed_varinfo, + DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @test (@inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); + true) + + varinfo_linked = DynamicPPL.link(varinfo, model) + @test ( + @inferred( + DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) + ); + true + ) + end + end + end end From 53e07cdd8bb142524739a6cbea6dbdc9c3af47b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Oct 2023 17:49:52 +0100 Subject: [PATCH 3/3] bumped patch versionion --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e9c88fa9f..5657b153e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.19" +version = "0.23.20" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"