From d173586b41fa4d45f33627f6df3c6e585b6b439f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 29 Nov 2024 11:29:50 +0000 Subject: [PATCH] Add demo_assume_literal_observe + rename demo_assume_observe_literal -> demo_assume_multivariate_observe_literal --- src/test_utils/models.jl | 50 ++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 4aa2aaa42..e721a75ef 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end -@model function demo_assume_observe_literal() - # `assume` and literal `observe` +@model function demo_assume_multivariate_observe_literal() + # multivariate `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m_dist = MvNormal(zeros(2), Diagonal(s)) return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function loglikelihood_true( + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m +) return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_observe_literal)}, s, m + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}) return [@varname(s), @varname(m)] end @@ -377,6 +379,30 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end +@model function demo_assume_literal_observe() + # univariate `assume` and literal `observe` + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_literal_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_literal_observe)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_literal_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_literal_observe)}) + return [@varname(s), @varname(m)] +end + @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` s ~ InverseGamma(2, 3) @@ -575,7 +601,8 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_observe_index)}, Model{typeof(demo_assume_dot_observe)}, Model{typeof(demo_assume_literal_dot_observe)}, - Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_literal_observe)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -585,7 +612,9 @@ const DemoModels = Union{ } const UnivariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)} + Model{typeof(demo_assume_literal_observe)} } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, - Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -759,9 +788,10 @@ const DEMO_MODELS = ( demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), demo_assume_dot_observe(), - demo_assume_observe_literal(), + demo_assume_multivariate_observe_literal(), demo_dot_assume_observe_index_literal(), demo_assume_literal_dot_observe(), + demo_assume_literal_observe(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(),