Skip to content

Commit

Permalink
Add demo_assume_literal_observe + rename demo_assume_observe_literal …
Browse files Browse the repository at this point in the history
…-> demo_assume_multivariate_observe_literal
  • Loading branch information
penelopeysm committed Nov 29, 2024
1 parent a53e37f commit d173586
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_observe_literal()
# `assume` and literal `observe`
@model function demo_assume_multivariate_observe_literal()
# multivariate `assume` and literal `observe`
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m ~ MvNormal(zeros(2), Diagonal(s))
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m_dist = MvNormal(zeros(2), Diagonal(s))
return logpdf(s_dist, s) + logpdf(m_dist, m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function loglikelihood_true(
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_observe_literal)}, s, m
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_observe_literal)})
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
return [@varname(s), @varname(m)]
end

Expand Down Expand Up @@ -377,6 +379,30 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
end

@model function demo_assume_literal_observe()
# univariate `assume` and literal `observe`
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_literal_observe)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_literal_observe)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_literal_dot_observe()
# `assume` and literal `dot_observe`
s ~ InverseGamma(2, 3)
Expand Down Expand Up @@ -575,7 +601,8 @@ const DemoModels = Union{
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_literal_dot_observe)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_literal_observe)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand All @@ -585,7 +612,9 @@ const DemoModels = Union{
}

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_literal_dot_observe)}
Model{typeof(demo_assume_literal_observe)}
}
function posterior_mean(model::UnivariateAssumeDemoModels)
return (s=49 / 24, m=7 / 6)
Expand All @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_index_observe)},
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand Down Expand Up @@ -759,9 +788,10 @@ const DEMO_MODELS = (
demo_assume_multivariate_observe(),
demo_dot_assume_observe_index(),
demo_assume_dot_observe(),
demo_assume_observe_literal(),
demo_assume_multivariate_observe_literal(),
demo_dot_assume_observe_index_literal(),
demo_assume_literal_dot_observe(),
demo_assume_literal_observe(),
demo_assume_submodel_observe_index_literal(),
demo_dot_assume_observe_submodel(),
demo_dot_assume_dot_observe_matrix(),
Expand Down

0 comments on commit d173586

Please sign in to comment.